Commit d5d0921 by zhzluke96 · Parent(s): 2be0618
update
Browse files:
- data/speakers/Bob_ft10.pt +3 -0
- modules/ChatTTS/ChatTTS/core.py +1 -1
- modules/SynthesizeSegments.py +40 -7
- modules/api/app_config.py +2 -2
- modules/api/impl/google_api.py +66 -107
- modules/api/impl/handler/AudioHandler.py +37 -0
- modules/api/impl/handler/SSMLHandler.py +94 -0
- modules/api/impl/handler/TTSHandler.py +97 -0
- modules/api/impl/model/audio_model.py +14 -0
- modules/api/impl/model/chattts_model.py +19 -0
- modules/api/impl/model/enhancer_model.py +11 -0
- modules/api/impl/openai_api.py +57 -56
- modules/api/impl/refiner_api.py +1 -0
- modules/api/impl/ssml_api.py +30 -25
- modules/api/impl/tts_api.py +58 -31
- modules/api/impl/xtts_v2_api.py +52 -6
- modules/api/utils.py +2 -11
- modules/devices/devices.py +7 -1
- modules/finetune/train_speaker.py +18 -11
- modules/prompts/news_oral_prompt.txt +14 -0
- modules/prompts/podcast_prompt.txt +1 -0
- modules/ssml_parser/SSMLParser.py +1 -4
- modules/webui/speaker/speaker_editor.py +1 -1
- modules/webui/speaker/speaker_merger.py +2 -6
data/speakers/Bob_ft10.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91015b82a99c40034048090228b6d647ab99fd7b86e8babd6a7c3a9236e8d800
+size 4508
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
 
 from huggingface_hub import snapshot_download
 
-logging.basicConfig(level=logging.…)
+logging.basicConfig(level=logging.INFO)
 
 
 class Chat:
modules/SynthesizeSegments.py CHANGED
@@ -1,4 +1,5 @@
 import copy
+import re
 from box import Box
 from pydub import AudioSegment
 from typing import List, Union
@@ -160,7 +161,21 @@ class SynthesizeSegments:
        for i in range(0, len(bucket), self.batch_size):
            batch = bucket[i : i + self.batch_size]
            param_arr = [self.segment_to_generate_params(segment) for segment in batch]
-            …
+
+            def append_eos(text: str):
+                text = text.strip()
+                eos_arr = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"]
+                has_eos = False
+                for eos in eos_arr:
+                    if eos in text:
+                        has_eos = True
+                        break
+                if not has_eos:
+                    text += self.eos
+                return text
+
+            # end_of_text is appended after the text here
+            texts = [append_eos(params.text) for params in param_arr]
 
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
@@ -182,6 +197,7 @@ class SynthesizeSegments:
 
            audio_segment = audio_data_to_segment(audio_data, sr)
            audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
+            # compare by Box object
            original_index = src_segments.index(segment)
            audio_segments[original_index] = audio_segment
 
@@ -226,13 +242,30 @@ class SynthesizeSegments:
 
        sentences = spliter.parse(text)
        for sentence in sentences:
-            …
-                params=copy.copy(segment.params),
-            )
+            seg = SSMLSegment(
+                text=sentence,
+                attrs=segment.attrs.copy(),
+                params=copy.copy(segment.params),
            )
+            ret_segments.append(seg)
+            setattr(seg, "_idx", len(ret_segments) - 1)
+
+        def is_none_speak_segment(segment: SSMLSegment):
+            text = segment.text.strip()
+            regexp = r"\[[^\]]+?\]"
+            text = re.sub(regexp, "", text)
+            text = text.strip()
+            if not text:
+                return True
+            return False
+
+        # merge none_speak segments into the previous speak segment
+        for i in range(1, len(ret_segments)):
+            if is_none_speak_segment(ret_segments[i]):
+                ret_segments[i - 1].text += ret_segments[i].text
+                ret_segments[i].text = ""
+        # remove empty segments
+        ret_segments = [seg for seg in ret_segments if seg.text.strip()]
 
        return ret_segments
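To make the new merge step concrete, here is a minimal standalone sketch of the same logic; plain strings stand in for SSMLSegment objects, and the snippet is illustrative rather than part of the commit:

import re

def is_none_speak(text: str) -> bool:
    # a segment is "none speak" when nothing remains after
    # stripping [uv_break]-style control tokens
    return not re.sub(r"\[[^\]]+?\]", "", text).strip()

segs = ["你好", "[uv_break]", "世界[lbreak]"]
for i in range(1, len(segs)):
    if is_none_speak(segs[i]):
        segs[i - 1] += segs[i]  # merge into the previous speak segment
        segs[i] = ""
segs = [s for s in segs if s.strip()]
print(segs)  # ['你好[uv_break]', '世界[lbreak]']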
modules/api/app_config.py CHANGED
@@ -1,6 +1,6 @@
 app_description = """
-ChatTTS-Forge …
-ChatTTS-Forge is a …
+🍦 ChatTTS-Forge 是一个围绕 TTS 生成模型 ChatTTS 开发的项目,实现了 API Server 和 基于 Gradio 的 WebUI。<br/>
+🍦 ChatTTS-Forge is a project developed around the TTS generation model ChatTTS, implementing an API Server and a Gradio-based WebUI.
 
 项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
modules/api/impl/google_api.py CHANGED
@@ -1,38 +1,25 @@
-import …
-from typing import Literal
+from typing import Union
 from fastapi import HTTPException
 
-import io
-import soundfile as sf
 from pydantic import BaseModel
 
 
-from modules.Enhancer.ResembleEnhance import (
-    apply_audio_enhance,
-    apply_audio_enhance_full,
-)
 from modules.api.Api import APIManager
-from modules.…
-from modules.…
-from modules.…
-from modules.…
+from modules.api.impl.handler.SSMLHandler import SSMLHandler
+from modules.api.impl.handler.TTSHandler import TTSHandler
+from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
 
-from modules import …
-from modules.speaker import speaker_mgr
+from modules.speaker import Speaker, speaker_mgr
 
 
-from modules.ssml_parser.SSMLParser import create_ssml_parser
-from modules.SynthesizeSegments import (
-    SynthesizeSegments,
-    combine_audio_segments,
-)
-
 from modules.api import utils as api_utils
 
 
 class SynthesisInput(BaseModel):
-    text: str = …
-    ssml: str = …
+    text: Union[str, None] = None
+    ssml: Union[str, None] = None
 
 
 class VoiceSelectionParams(BaseModel):
@@ -50,24 +37,15 @@ class VoiceSelectionParams(BaseModel):
 
 
 class AudioConfig(BaseModel):
-    audioEncoding: …
+    audioEncoding: AudioFormat = AudioFormat.mp3
    speakingRate: float = 1
    pitch: float = 0
    volumeGainDb: float = 0
    sampleRateHertz: int = 24000
-    batchSize: int = …
+    batchSize: int = 4
    spliterThreshold: int = 100
 
 
-class EnhancerConfig(BaseModel):
-    enabled: bool = False
-    model: str = "resemble-enhance"
-    nfe: int = 32
-    solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
-    lambd: float = 0.5
-    tau: float = 0.5
-
-
 class GoogleTextSynthesizeRequest(BaseModel):
    input: SynthesisInput
    voice: VoiceSelectionParams
@@ -92,7 +70,11 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    voice_name = voice.name
    infer_seed = voice.seed or 42
    eos = voice.eos or "[uv_break]"
-    audio_format = audioConfig.audioEncoding
+    audio_format = audioConfig.audioEncoding
+
+    if not isinstance(audio_format, AudioFormat) and isinstance(audio_format, str):
+        audio_format = AudioFormat(audio_format)
+
    speaking_rate = audioConfig.speakingRate or 1
    pitch = audioConfig.pitch or 0
    volume_gain_db = audioConfig.volumeGainDb or 0
@@ -101,6 +83,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
 
    spliter_threshold = audioConfig.spliterThreshold or 100
 
+    # TODO
    sample_rate = audioConfig.sampleRateHertz or 24000
 
    params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
@@ -111,92 +94,68 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
            status_code=422, detail="The specified voice name is not supported."
        )
 
-    if …:
+    if not isinstance(params.get("spk"), Speaker):
        raise HTTPException(
-            status_code=422, detail="…"
+            status_code=422, detail="The specified voice name is not supported."
        )
 
+    speaker = params.get("spk")
+    tts_config = ChatTTSConfig(
+        style=params.get("style", ""),
+        temperature=voice.temperature,
+        top_k=voice.topK,
+        top_p=voice.topP,
+    )
+    infer_config = InferConfig(
+        batch_size=batch_size,
+        spliter_threshold=spliter_threshold,
+        eos=eos,
+        seed=infer_seed,
+    )
+    adjust_config = AdjustConfig(
+        speaking_rate=speaking_rate,
+        pitch=pitch,
+        volume_gain_db=volume_gain_db,
+    )
+    enhancer_config = enhancerConfig
+
+    mime_type = f"audio/{audio_format.value}"
+    if audio_format == AudioFormat.mp3:
+        mime_type = "audio/mpeg"
    try:
        if input.text:
-            …
-            top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
-            top_K=voice.topK if voice.topK else params.get("top_k", 20),
-            spk=params.get("spk", -1),
-            infer_seed=infer_seed,
-            prompt1=params.get("prompt1", ""),
-            prompt2=params.get("prompt2", ""),
-            prefix=params.get("prefix", ""),
-            batch_size=batch_size,
-            spliter_threshold=spliter_threshold,
-            end_of_sentence=eos,
-            )
+            text_content = input.text
+
+            handler = TTSHandler(
+                text_content=text_content,
+                spk=speaker,
+                tts_config=tts_config,
+                infer_config=infer_config,
+                adjust_config=adjust_config,
+                enhancer_config=enhancer_config,
            )
 
-            segments = parser.parse(input.ssml)
-            for seg in segments:
-                seg["text"] = text_normalize(seg["text"], is_end=True)
-
-            if len(segments) == 0:
-                raise HTTPException(
-                    status_code=422, detail="The SSML text is empty or parsing failed."
-                )
-
-            synthesize = SynthesizeSegments(
-                batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
-            )
-            audio_segments = synthesize.synthesize_segments(segments)
-            combined_audio = combine_audio_segments(audio_segments)
+            base64_string = handler.enqueue_to_base64(format=audio_format)
+            return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
 
-            raise HTTPException(
-                status_code=422, detail="Either text or SSML input must be provided."
-            )
+        elif input.ssml:
+            ssml_content = input.ssml
 
-            …
-            solver=enhancerConfig.solver,
-            lambd=enhancerConfig.lambd,
-            tau=enhancerConfig.tau,
-        )
+            handler = SSMLHandler(
+                ssml_content=ssml_content,
+                infer_config=infer_config,
+                adjust_config=adjust_config,
+                enhancer_config=enhancer_config,
+            )
 
-            audio_data,
-            rate=speaking_rate,
-            pitch=pitch,
-            volume=volume_gain_db,
-            sr=sample_rate,
-        )
-
-        buffer = io.BytesIO()
-        sf.write(buffer, audio_data, sample_rate, format="wav")
-        buffer.seek(0)
+            base64_string = handler.enqueue_to_base64(format=audio_format)
 
-        buffer = api_utils.wav_to_mp3(buffer)
+            return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
 
-        …
-            "audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
-        }
+        else:
+            raise HTTPException(
+                status_code=422, detail="Invalid input text or ssml specified."
+            )
 
    except Exception as e:
        import logging
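A client-call sketch for this endpoint; the base URL and route path are assumptions (the setup() registration is outside this diff, and Google's own API uses a text:synthesize route), while the field names come from the request models above:

import requests

payload = {
    "input": {"text": "Hello [uv_break] world [lbreak]"},
    "voice": {"name": "female2", "seed": 42, "eos": "[uv_break]"},
    "audioConfig": {"audioEncoding": "mp3", "speakingRate": 1.2},
}
# assumed host, port, and path
resp = requests.post("http://localhost:7870/v1/text:synthesize", json=payload)
print(resp.json()["audioContent"][:40])  # "data:audio/mpeg;base64,..."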
modules/api/impl/handler/AudioHandler.py ADDED
@@ -0,0 +1,37 @@
+import base64
+import io
+import numpy as np
+import soundfile as sf
+
+from modules.api.impl.model.audio_model import AudioFormat
+from modules.api import utils as api_utils
+
+
+class AudioHandler:
+    def enqueue(self) -> tuple[np.ndarray, int]:
+        raise NotImplementedError
+
+    def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
+        audio_data, sample_rate = self.enqueue()
+
+        buffer = io.BytesIO()
+        sf.write(buffer, audio_data, sample_rate, format="wav")
+        buffer.seek(0)
+
+        if format == AudioFormat.mp3:
+            buffer = api_utils.wav_to_mp3(buffer)
+
+        return buffer
+
+    def enqueue_to_bytes(self, format: AudioFormat) -> bytes:
+        buffer = self.enqueue_to_buffer(format=format)
+        binary = buffer.read()
+        return binary
+
+    def enqueue_to_base64(self, format: AudioFormat) -> str:
+        binary = self.enqueue_to_bytes(format=format)
+
+        base64_encoded = base64.b64encode(binary)
+        base64_string = base64_encoded.decode("utf-8")
+
+        return base64_string
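For reference, a minimal sketch of how a subclass plugs into this base class; SineHandler and its one-second 440 Hz test tone are hypothetical, only enqueue()/enqueue_to_base64() come from the code above:

import numpy as np

from modules.api.impl.handler.AudioHandler import AudioHandler
from modules.api.impl.model.audio_model import AudioFormat


class SineHandler(AudioHandler):
    # hypothetical handler that "synthesizes" one second of a 440 Hz tone
    def enqueue(self) -> tuple[np.ndarray, int]:
        sr = 24000
        t = np.linspace(0, 1.0, sr, endpoint=False)
        return 0.2 * np.sin(2 * np.pi * 440 * t), sr


b64 = SineHandler().enqueue_to_base64(format=AudioFormat.wav)
print(b64[:32])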
modules/api/impl/handler/SSMLHandler.py ADDED
@@ -0,0 +1,94 @@
+from fastapi import HTTPException
+import numpy as np
+
+from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
+from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
+from modules.api.impl.handler.AudioHandler import AudioHandler
+from modules.api.impl.model.audio_model import AdjustConfig
+from modules.api.impl.model.chattts_model import InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
+from modules.normalization import text_normalize
+from modules.ssml_parser.SSMLParser import create_ssml_parser
+from modules.utils import audio
+
+
+class SSMLHandler(AudioHandler):
+    def __init__(
+        self,
+        ssml_content: str,
+        infer_config: InferConfig,
+        adjust_config: AdjustConfig,
+        enhancer_config: EnhancerConfig,
+    ) -> None:
+        assert isinstance(ssml_content, str), "ssml_content must be a string."
+        assert isinstance(
+            infer_config, InferConfig
+        ), "infer_config must be an InferConfig object."
+        assert isinstance(
+            adjust_config, AdjustConfig
+        ), "adjest_config should be AdjustConfig"
+        assert isinstance(
+            enhancer_config, EnhancerConfig
+        ), "enhancer_config must be an EnhancerConfig object."
+
+        self.ssml_content = ssml_content
+        self.infer_config = infer_config
+        self.adjest_config = adjust_config
+        self.enhancer_config = enhancer_config
+
+        self.validate()
+
+    def validate(self):
+        # TODO params checker
+        pass
+
+    def enqueue(self) -> tuple[np.ndarray, int]:
+        ssml_content = self.ssml_content
+        infer_config = self.infer_config
+        adjust_config = self.adjest_config
+        enhancer_config = self.enhancer_config
+
+        parser = create_ssml_parser()
+        segments = parser.parse(ssml_content)
+        for seg in segments:
+            seg["text"] = text_normalize(seg["text"], is_end=True)
+
+        if len(segments) == 0:
+            raise HTTPException(
+                status_code=422, detail="The SSML text is empty or parsing failed."
+            )
+
+        synthesize = SynthesizeSegments(
+            batch_size=infer_config.batch_size,
+            eos=infer_config.eos,
+            spliter_thr=infer_config.spliter_threshold,
+        )
+        audio_segments = synthesize.synthesize_segments(segments)
+        combined_audio = combine_audio_segments(audio_segments)
+
+        sample_rate, audio_data = audio.pydub_to_np(combined_audio)
+
+        if enhancer_config.enabled:
+            nfe = enhancer_config.nfe
+            solver = enhancer_config.solver
+            lambd = enhancer_config.lambd
+            tau = enhancer_config.tau
+
+            audio_data, sample_rate = apply_audio_enhance_full(
+                audio_data=audio_data,
+                sr=sample_rate,
+                nfe=nfe,
+                solver=solver,
+                lambd=lambd,
+                tau=tau,
+            )
+
+        audio_data = audio.apply_prosody_to_audio_data(
+            audio_data=audio_data,
+            rate=adjust_config.speed_rate,
+            pitch=adjust_config.pitch,
+            volume=adjust_config.volume_gain_db,
+            sr=sample_rate,
+        )
+
+        return audio_data, sample_rate
modules/api/impl/handler/TTSHandler.py ADDED
@@ -0,0 +1,97 @@
+import numpy as np
+from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
+from modules.api.impl.handler.AudioHandler import AudioHandler
+from modules.api.impl.model.audio_model import AdjustConfig
+from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
+from modules.normalization import text_normalize
+from modules.speaker import Speaker
+from modules.synthesize_audio import synthesize_audio
+
+from modules.utils.audio import apply_prosody_to_audio_data
+
+
+class TTSHandler(AudioHandler):
+    def __init__(
+        self,
+        text_content: str,
+        spk: Speaker,
+        tts_config: ChatTTSConfig,
+        infer_config: InferConfig,
+        adjust_config: AdjustConfig,
+        enhancer_config: EnhancerConfig,
+    ):
+        assert isinstance(text_content, str), "text_content should be str"
+        assert isinstance(spk, Speaker), "spk should be Speaker"
+        assert isinstance(
+            tts_config, ChatTTSConfig
+        ), "tts_config should be ChatTTSConfig"
+        assert isinstance(
+            infer_config, InferConfig
+        ), "infer_config should be InferConfig"
+        assert isinstance(
+            adjust_config, AdjustConfig
+        ), "adjest_config should be AdjustConfig"
+        assert isinstance(
+            enhancer_config, EnhancerConfig
+        ), "enhancer_config should be EnhancerConfig"
+
+        self.text_content = text_content
+        self.spk = spk
+        self.tts_config = tts_config
+        self.infer_config = infer_config
+        self.adjest_config = adjust_config
+        self.enhancer_config = enhancer_config
+
+        self.validate()
+
+    def validate(self):
+        # TODO params checker
+        pass
+
+    def enqueue(self) -> tuple[np.ndarray, int]:
+        text = text_normalize(self.text_content)
+        tts_config = self.tts_config
+        infer_config = self.infer_config
+        adjust_config = self.adjest_config
+        enhancer_config = self.enhancer_config
+
+        sample_rate, audio_data = synthesize_audio(
+            text,
+            spk=self.spk,
+            temperature=tts_config.temperature,
+            top_P=tts_config.top_p,
+            top_K=tts_config.top_k,
+            prompt1=tts_config.prompt1,
+            prompt2=tts_config.prompt2,
+            prefix=tts_config.prefix,
+            infer_seed=infer_config.seed,
+            batch_size=infer_config.batch_size,
+            spliter_threshold=infer_config.spliter_threshold,
+            end_of_sentence=infer_config.eos,
+        )
+
+        if enhancer_config.enabled:
+            nfe = enhancer_config.nfe
+            solver = enhancer_config.solver
+            lambd = enhancer_config.lambd
+            tau = enhancer_config.tau
+
+            audio_data, sample_rate = apply_audio_enhance_full(
+                audio_data=audio_data,
+                sr=sample_rate,
+                nfe=nfe,
+                solver=solver,
+                lambd=lambd,
+                tau=tau,
+            )
+
+        audio_data = apply_prosody_to_audio_data(
+            audio_data=audio_data,
+            rate=adjust_config.speed_rate,
+            pitch=adjust_config.pitch,
+            volume=adjust_config.volume_gain_db,
+            sr=sample_rate,
+        )
+
+        return audio_data, sample_rate
modules/api/impl/model/audio_model.py ADDED
@@ -0,0 +1,14 @@
+from enum import Enum
+
+from pydantic import BaseModel
+
+
+class AudioFormat(str, Enum):
+    mp3 = "mp3"
+    wav = "wav"
+
+
+class AdjustConfig(BaseModel):
+    pitch: float = 0
+    speed_rate: float = 1
+    volume_gain_db: float = 0
modules/api/impl/model/chattts_model.py ADDED
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class ChatTTSConfig(BaseModel):
+    style: str = ""
+    temperature: float = 0.3
+    top_p: float = 0.7
+    top_k: int = 20
+    prompt1: str = ""
+    prompt2: str = ""
+    prefix: str = ""
+
+
+class InferConfig(BaseModel):
+    batch_size: int = 4
+    spliter_threshold: int = 100
+    # end_of_sentence
+    eos: str = "[uv_break]"
+    seed: int = 42
modules/api/impl/model/enhancer_model.py ADDED
@@ -0,0 +1,11 @@
+from typing import Literal
+from pydantic import BaseModel
+
+
+class EnhancerConfig(BaseModel):
+    enabled: bool = False
+    model: str = "resemble-enhance"
+    nfe: int = 32
+    solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
+    lambd: float = 0.5
+    tau: float = 0.5
modules/api/impl/openai_api.py CHANGED
@@ -1,42 +1,38 @@
 from fastapi import File, Form, HTTPException, Body, UploadFile
-from fastapi.responses import StreamingResponse
 
-import io
 from numpy import clip
-import soundfile as sf
 from pydantic import BaseModel, Field
-from fastapi.responses import …
-
+from fastapi.responses import StreamingResponse
 
-from modules.synthesize_audio import synthesize_audio
-from modules.normalization import text_normalize
 
-from modules import …
+from modules.api.impl.handler.TTSHandler import TTSHandler
+from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
-from typing import List, …
-import pyrubberband as pyrb
+from typing import List, Optional
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 
-from modules.speaker import speaker_mgr
+from modules.speaker import Speaker, speaker_mgr
 from modules.data import styles_mgr
 
-import numpy as np
-
 
 class AudioSpeechRequest(BaseModel):
    input: str  # the text to synthesize
    model: str = "chattts-4w"
    voice: str = "female2"
-    response_format: …
+    response_format: AudioFormat = "mp3"
    speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
    seed: int = 42
+
    temperature: float = 0.3
+    top_k: int = 20
+    top_p: float = 0.7
+
    style: str = ""
-    # whether batch synthesis is enabled; <= 1 means batching is not used
-    # batch synthesis automatically splits sentences
    batch_size: int = Field(1, ge=1, le=20, description="Batch size")
    spliter_threshold: float = Field(
        100, ge=10, le=1024, description="Threshold for sentence spliter"
@@ -44,6 +40,9 @@ class AudioSpeechRequest(BaseModel):
    # end of sentence
    eos: str = "[uv_break]"
 
+    enhance: bool = False
+    denoise: bool = False
+
 
 async def openai_speech_api(
    request: AudioSpeechRequest = Body(
@@ -55,7 +54,14 @@ async def openai_speech_api(
    voice = request.voice
    style = request.style
    eos = request.eos
+    seed = request.seed
+
    response_format = request.response_format
+    if not isinstance(response_format, AudioFormat) and isinstance(
+        response_format, str
+    ):
+        response_format = AudioFormat(response_format)
+
    batch_size = request.batch_size
    spliter_threshold = request.spliter_threshold
    speed = request.speed
@@ -71,49 +77,45 @@ async def openai_speech_api(
    except:
        raise HTTPException(status_code=400, detail="Invalid style.")
 
-
-    # Normalize the text
-    text = text_normalize(input_text, is_end=True)
-
-    # Calculate speaker and style based on input voice
-    params = api_utils.calc_spk_style(spk=voice, style=style)
-
-    spk = params.get("spk", -1)
-    seed = params.get("seed", request.seed or 42)
-    temperature = params.get("temperature", request.temperature or 0.3)
-    prompt1 = params.get("prompt1", "")
-    prompt2 = params.get("prompt2", "")
-    prefix = params.get("prefix", "")
-
-    # Generate audio
-    sample_rate, audio_data = synthesize_audio(
-        text,
-        temperature=temperature,
-        top_P=0.7,
-        top_K=20,
-        spk=spk,
-        infer_seed=seed,
-        batch_size=batch_size,
-        spliter_threshold=spliter_threshold,
-        prompt1=prompt1,
-        prompt2=prompt2,
-        prefix=prefix,
-        end_of_sentence=eos,
-    )
+    ctx_params = api_utils.calc_spk_style(spk=voice, style=style)
 
+    speaker = ctx_params.get("spk")
+    if not isinstance(speaker, Speaker):
+        raise HTTPException(status_code=400, detail="Invalid voice.")
 
+    tts_config = ChatTTSConfig(
+        style=style,
+        temperature=request.temperature,
+        top_k=request.top_k,
+        top_p=request.top_p,
+    )
+    infer_config = InferConfig(
+        batch_size=batch_size,
+        spliter_threshold=spliter_threshold,
+        eos=eos,
+        seed=seed,
+    )
+    adjust_config = AdjustConfig(speaking_rate=speed)
+    enhancer_config = EnhancerConfig(
+        enabled=request.enhance or request.denoise or False,
+        lambd=0.9 if request.denoise else 0.1,
+    )
+    try:
+        handler = TTSHandler(
+            text_content=input_text,
+            spk=speaker,
+            tts_config=tts_config,
+            infer_config=infer_config,
+            adjust_config=adjust_config,
+            enhancer_config=enhancer_config,
+        )
 
-        …
-        # Convert wav to mp3
-        buffer = api_utils.wav_to_mp3(buffer)
+        buffer = handler.enqueue_to_buffer(response_format)
 
+        mime_type = f"audio/{response_format.value}"
+        if response_format == AudioFormat.mp3:
+            mime_type = "audio/mpeg"
+        return StreamingResponse(buffer, media_type=mime_type)
 
    except Exception as e:
        import logging
@@ -150,7 +152,6 @@ class TranscriptionsVerboseResponse(BaseModel):
 def setup(app: APIManager):
    app.post(
        "/v1/audio/speech",
-        response_class=FileResponse,
        description="""
openai api document:
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
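Since setup() above registers the route at /v1/audio/speech, a client call could look like this sketch; the host and port are assumptions, the body fields come from AudioSpeechRequest:

import requests

resp = requests.post(
    "http://localhost:7870/v1/audio/speech",  # assumed base URL
    json={
        "input": "Hello [uv_break] world [lbreak]",
        "voice": "female2",
        "response_format": "mp3",
        "batch_size": 4,
        "enhance": False,
        "denoise": False,
    },
)
with open("speech.mp3", "wb") as f:
    f.write(resp.content)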
modules/api/impl/refiner_api.py CHANGED
@@ -31,6 +31,7 @@ async def refiner_prompt_post(request: RefineTextRequest):
    text = request.text
    if request.normalize:
        text = text_normalize(request.text)
+    # TODO: spliter and batch processing could actually be applied here
    refined_text = refiner.refine_text(
        text=text,
        prompt=request.prompt,
modules/api/impl/ssml_api.py CHANGED
@@ -1,27 +1,22 @@
 from fastapi import HTTPException, Body
 from fastapi.responses import StreamingResponse
 
-import io
 from pydantic import BaseModel
 from fastapi.responses import FileResponse
 
 
-from modules.…
-from modules.…
-from modules.…
-    …
-    combine_audio_segments,
-)
+from modules.api.impl.handler.SSMLHandler import SSMLHandler
+from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+from modules.api.impl.model.chattts_model import InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
-from modules.api import utils as api_utils
-
 from modules.api.Api import APIManager
 
 
 class SSMLRequest(BaseModel):
    ssml: str
-    format: …
+    format: AudioFormat = "mp3"
 
    # NOTE: 🤔 maybe this value should be a system-level setting? passing it in feels a bit odd
    batch_size: int = 4
@@ -31,6 +26,9 @@ class SSMLRequest(BaseModel):
 
    spliter_thr: int = 100
 
+    enhancer: EnhancerConfig = EnhancerConfig()
+    adjuster: AdjustConfig = AdjustConfig()
+
 
 async def synthesize_ssml_api(
    request: SSMLRequest = Body(
@@ -43,6 +41,8 @@ async def synthesize_ssml_api(
    batch_size = request.batch_size
    eos = request.eos
    spliter_thr = request.spliter_thr
+    enhancer = request.enhancer
+    adjuster = request.adjuster
 
    if batch_size < 1:
        raise HTTPException(
@@ -62,22 +62,27 @@ async def synthesize_ssml_api(
                status_code=400, detail="Format must be 'mp3' or 'wav'."
            )
 
-        …
-        synthesize = SynthesizeSegments(
-            batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
-        )
-        …
+        infer_config = InferConfig(
+            batch_size=batch_size,
+            spliter_threshold=spliter_thr,
+            eos=eos,
+        )
+        adjust_config = adjuster
+        enhancer_config = enhancer
+
+        handler = SSMLHandler(
+            ssml_content=ssml,
+            infer_config=infer_config,
+            adjust_config=adjust_config,
+            enhancer_config=enhancer_config,
+        )
+
+        buffer = handler.enqueue_to_buffer(format=request.format)
+
+        mime_type = f"audio/{format}"
+        if format == AudioFormat.mp3:
+            mime_type = "audio/mpeg"
+        return StreamingResponse(buffer, media_type=mime_type)
 
    except Exception as e:
        import logging
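A request sketch for this endpoint; the route path and base URL are assumptions (the registration is outside this diff), and the SSML shape follows the project's examples, so treat the exact element attributes as assumptions too:

import requests

ssml = """
<speak version="0.1">
    <voice spk="female2">你好 [uv_break] 世界 [lbreak]</voice>
</speak>
"""
resp = requests.post(
    "http://localhost:7870/v1/ssml",  # assumed route and base URL
    json={"ssml": ssml, "format": "mp3", "batch_size": 4},
)
open("ssml.mp3", "wb").write(resp.content)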
modules/api/impl/tts_api.py CHANGED
@@ -1,17 +1,18 @@
 from fastapi import Depends, HTTPException, Query
 from fastapi.responses import StreamingResponse
 
-import io
 from pydantic import BaseModel
-import soundfile as sf
 from fastapi.responses import FileResponse
 
 
-from modules.…
+from modules.api.impl.handler.TTSHandler import TTSHandler
+from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
-from modules.…
+from modules.speaker import Speaker
 
 
 class TTSParams(BaseModel):
@@ -23,10 +24,10 @@ class TTSParams(BaseModel):
    temperature: float = Query(
        0.3, description="Temperature for sampling (may be overridden by style or spk)"
    )
-    …
+    top_p: float = Query(
        0.5, description="Top P for sampling (may be overridden by style or spk)"
    )
-    …
+    top_k: int = Query(
        20, description="Top K for sampling (may be overridden by style or spk)"
    )
    seed: int = Query(
@@ -38,7 +39,14 @@ class TTSParams(BaseModel):
    prefix: str = Query("", description="Text prefix for inference")
    bs: str = Query("8", description="Batch size for inference")
    thr: str = Query("100", description="Threshold for sentence spliter")
-    eos: str = Query("", description="End of sentence str")
+    eos: str = Query("[uv_break]", description="End of sentence str")
+
+    enhance: bool = Query(False, description="Enable enhancer")
+    denoise: bool = Query(False, description="Enable denoiser")
+
+    speed: float = Query(1.0, description="Speed of the audio")
+    pitch: float = Query(0, description="Pitch of the audio")
+    volume_gain: float = Query(0, description="Volume gain of the audio")
 
 
 async def synthesize_tts(params: TTSParams = Depends()):
@@ -55,18 +63,18 @@ async def synthesize_tts(params: TTSParams = Depends()):
            status_code=422, detail="Temperature must be between 0 and 1"
        )
 
-    # Validate …
-    if not (0 <= params.…):
-        raise HTTPException(status_code=422, detail="…")
+    # Validate top_p
+    if not (0 <= params.top_p <= 1):
+        raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")
 
-    # Validate …
-    if params.…:
+    # Validate top_k
+    if params.top_k <= 0:
        raise HTTPException(
-            status_code=422, detail="…"
+            status_code=422, detail="top_k must be a positive integer"
        )
-    if params.…:
+    if params.top_k > 100:
        raise HTTPException(
-            status_code=422, detail="…"
+            status_code=422, detail="top_k must be less than or equal to 100"
        )
 
    # Validate format
@@ -76,11 +84,13 @@ async def synthesize_tts(params: TTSParams = Depends()):
            detail="Invalid format. Supported formats are mp3 and wav",
        )
 
-    text = text_normalize(params.text, is_end=False)
-
    calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
 
    spk = calc_params.get("spk", params.spk)
+    if not isinstance(spk, Speaker):
+        raise HTTPException(status_code=422, detail="Invalid speaker")
+
+    style = calc_params.get("style", params.style)
    seed = params.seed or calc_params.get("seed", params.seed)
    temperature = params.temperature or calc_params.get(
        "temperature", params.temperature
@@ -93,29 +103,46 @@ async def synthesize_tts(params: TTSParams = Depends()):
    batch_size = int(params.bs)
    threshold = int(params.thr)
 
-    …
+    tts_config = ChatTTSConfig(
+        style=style,
        temperature=temperature,
-        …
-        infer_seed=seed,
+        top_k=params.top_k,
+        top_p=params.top_p,
+        prefix=prefix,
        prompt1=prompt1,
        prompt2=prompt2,
-        …
+    )
+    infer_config = InferConfig(
        batch_size=batch_size,
        spliter_threshold=threshold,
-        …
+        eos=eos,
+        seed=seed,
+    )
+    adjust_config = AdjustConfig(
+        pitch=params.pitch,
+        speed_rate=params.speed,
+        volume_gain_db=params.volume_gain,
+    )
+    enhancer_config = EnhancerConfig(
+        enabled=params.enhance or params.denoise or False,
+        lambd=0.9 if params.denoise else 0.1,
    )
 
-    …
-    buffer = api_utils.wav_to_mp3(buffer)
+    handler = TTSHandler(
+        text_content=params.text,
+        spk=spk,
+        tts_config=tts_config,
+        infer_config=infer_config,
+        adjust_config=adjust_config,
+        enhancer_config=enhancer_config,
+    )
 
-    …
+    buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
+
+    media_type = f"audio/{params.format}"
+    if params.format == "mp3":
+        media_type = "audio/mpeg"
+    return StreamingResponse(buffer, media_type=media_type)
 
    except Exception as e:
        import logging
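Because TTSParams is injected with Depends(), these fields arrive as plain query parameters; a client sketch follows, with the route path and base URL as assumptions:

import requests

resp = requests.get(
    "http://localhost:7870/v1/tts",  # assumed route and base URL
    params={
        "text": "你好 [uv_break] 世界",
        "spk": "female2",
        "top_p": 0.7,
        "top_k": 20,
        "speed": 1.2,
        "format": "mp3",
    },
)
open("tts.mp3", "wb").write(resp.content)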
modules/api/impl/xtts_v2_api.py CHANGED
@@ -30,8 +30,19 @@ class XTTS_V2_Settings:
        self.top_k = 20
        self.enable_text_splitting = True
 
+        # extra settings below are not part of xtts_v2 but are needed by this system
+        self.batch_size = 4
+        self.eos = "[uv_break]"
+        self.infer_seed = 42
+        self.use_decoder = True
+        self.prompt1 = ""
+        self.prompt2 = ""
+        self.prefix = ""
+        self.spliter_threshold = 100
+
 
 class TTSSettingsRequest(BaseModel):
+    # stream_chunk_size is now used as spliter_threshold
    stream_chunk_size: int
    temperature: float
    speed: float
@@ -41,6 +52,15 @@ class TTSSettingsRequest(BaseModel):
    top_k: int
    enable_text_splitting: bool
 
+    batch_size: int = None
+    eos: str = None
+    infer_seed: int = None
+    use_decoder: bool = None
+    prompt1: str = None
+    prompt2: str = None
+    prefix: str = None
+    spliter_threshold: int = None
+
 
 class SynthesisRequest(BaseModel):
    text: str
@@ -79,17 +99,22 @@ def setup(app: APIManager):
 
        text = text_normalize(text, is_end=True)
        sample_rate, audio_data = synthesize_audio(
-            …
-            temperature=XTTSV2.temperature,
+            # TODO: these two parameters are unused for now... though the gpt could actually use them
            # length_penalty=XTTSV2.length_penalty,
            # repetition_penalty=XTTSV2.repetition_penalty,
+            text=text,
+            temperature=XTTSV2.temperature,
            top_P=XTTSV2.top_p,
            top_K=XTTSV2.top_k,
            spk=spk,
-            spliter_threshold=XTTSV2.…
-            …
+            spliter_threshold=XTTSV2.spliter_threshold,
+            batch_size=XTTSV2.batch_size,
+            end_of_sentence=XTTSV2.eos,
+            infer_seed=XTTSV2.infer_seed,
+            use_decoder=XTTSV2.use_decoder,
+            prompt1=XTTSV2.prompt1,
+            prompt2=XTTSV2.prompt2,
+            prefix=XTTSV2.prefix,
        )
 
        if XTTSV2.speed:
@@ -145,6 +170,8 @@ def setup(app: APIManager):
            )
 
            XTTSV2.stream_chunk_size = request.stream_chunk_size
+            XTTSV2.spliter_threshold = request.stream_chunk_size
+
            XTTSV2.temperature = request.temperature
            XTTSV2.speed = request.speed
            XTTSV2.length_penalty = request.length_penalty
@@ -152,6 +179,25 @@ def setup(app: APIManager):
            XTTSV2.top_p = request.top_p
            XTTSV2.top_k = request.top_k
            XTTSV2.enable_text_splitting = request.enable_text_splitting
+
+            # TODO: checker
+            if request.batch_size:
+                XTTSV2.batch_size = request.batch_size
+            if request.eos:
+                XTTSV2.eos = request.eos
+            if request.infer_seed:
+                XTTSV2.infer_seed = request.infer_seed
+            if request.use_decoder:
+                XTTSV2.use_decoder = request.use_decoder
+            if request.prompt1:
+                XTTSV2.prompt1 = request.prompt1
+            if request.prompt2:
+                XTTSV2.prompt2 = request.prompt2
+            if request.prefix:
+                XTTSV2.prefix = request.prefix
+            if request.spliter_threshold:
+                XTTSV2.spliter_threshold = request.spliter_threshold
+
            return {"message": "Settings successfully applied"}
        except Exception as e:
            if isinstance(e, HTTPException):
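A settings-update sketch; the xtts_v2-compatible route path is an assumption (its registration is outside this diff), while the body fields mirror TTSSettingsRequest above. Note that stream_chunk_size now also drives spliter_threshold:

import requests

requests.post(
    "http://localhost:7870/v1/xtts_v2/set_tts_settings",  # assumed path
    json={
        "stream_chunk_size": 120,  # reused as spliter_threshold
        "temperature": 0.3,
        "speed": 1.0,
        "length_penalty": 1.0,
        "repetition_penalty": 1.0,
        "top_p": 0.7,
        "top_k": 20,
        "enable_text_splitting": True,
        "infer_seed": 42,
        "eos": "[lbreak]",
    },
)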
modules/api/utils.py CHANGED
@@ -1,9 +1,8 @@
 from pydantic import BaseModel
 from typing import Any, Union
 
-import torch
 
-from modules.speaker import …
+from modules.speaker import speaker_mgr
 
 
 from modules.data import styles_mgr
@@ -13,18 +12,10 @@ from pydub import AudioSegment
 from modules.ssml import merge_prompt
 
 
-from enum import Enum
-
-
 class ParamsTypeError(Exception):
    pass
 
 
-class AudioFormat(str, Enum):
-    mp3 = "mp3"
-    wav = "wav"
-
-
 class BaseResponse(BaseModel):
    message: str
    data: Any
@@ -35,7 +26,7 @@ def success_response(data: Any, message: str = "ok") -> BaseResponse:
 
 
 def wav_to_mp3(wav_data, bitrate="48k"):
-    audio = AudioSegment.from_wav(
+    audio: AudioSegment = AudioSegment.from_wav(
        wav_data,
    )
    return audio.export(format="mp3", bitrate=bitrate)
modules/devices/devices.py CHANGED
@@ -127,6 +127,12 @@ def reset_device():
    global dtype_gpt
    global dtype_decoder
 
+    if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
+        logger.warning(
+            "Cannot use half precision with CPU, using full precision instead"
+        )
+        config.runtime_env_vars.no_half = True
+
    if not config.runtime_env_vars.no_half:
        dtype = torch.float16
        dtype_dvae = torch.float16
@@ -144,7 +150,7 @@ def reset_device():
 
        logger.info("Using full precision: torch.float32")
 
-    if config.runtime_env_vars.use_cpu…:
+    if "all" in config.runtime_env_vars.use_cpu:
        device = cpu
    else:
        device = get_optimal_device()
modules/finetune/train_speaker.py CHANGED
@@ -45,9 +45,10 @@ def train_speaker_embeddings(
        )
        for speaker in dataset.speakers
    }
-    …
+
+    for speaker_embed in speaker_embeds.values():
+        std, mean = chat.pretrain_models["spk_stat"].chunk(2)
+        speaker_embed.data = speaker_embed.data * std + mean
 
    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
@@ -166,13 +167,13 @@ def train_speaker_embeddings(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )
            loss = audio_loss
-            …
+
+            text_logits = gpt.head_text(text_hidden_states)
+            text_loss = loss_fn(
+                text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
+            )
+            loss += text_loss
+            logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
 
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
@@ -181,7 +182,12 @@ def train_speaker_embeddings(
            loss += 0.01 * mse_loss
 
            optimizer.zero_grad()
-            …
+
+            if train_text:
+                # just for test
+                text_loss.backward()
+            else:
+                loss.backward()
            torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
            optimizer.step()
            logger.meters["loss"].update(loss.item(), n=batch_size)
@@ -203,6 +209,7 @@ if __name__ == "__main__":
    from modules.speaker import Speaker
 
    config.runtime_env_vars.no_half = True
+    config.runtime_env_vars.use_cpu = []
    devices.reset_device()
 
    parser = argparse.ArgumentParser()
modules/prompts/news_oral_prompt.txt ADDED
@@ -0,0 +1,14 @@
+# Task requirements
+Task: turn a news script into an oral broadcast script
+
+You need to rewrite a news script into colloquial spoken broadcast text,
+and add some paralanguage tags where appropriate to give the text variety
+
+The paralanguage tags currently available are:
+- `[laugh]`: laughter
+- `[uv_break]`: an unvoiced pause
+- `[v_break]`: a voiced pause, such as "嗯" or "啊"
+- `[lbreak]`: a long pause, usually marking the end of a paragraph
+
+# Input
+{{USER_INPUT}}
modules/prompts/podcast_prompt.txt ADDED
@@ -0,0 +1 @@
+TODO
modules/ssml_parser/SSMLParser.py CHANGED
@@ -1,13 +1,10 @@
 from lxml import etree
 
 
-from typing import …
+from typing import List, Union
 import logging
 
-from modules.data import styles_mgr
-from modules.speaker import speaker_mgr
 from box import Box
-import copy
 
 
 class SSMLContext(Box):
modules/webui/speaker/speaker_editor.py CHANGED
@@ -25,7 +25,7 @@ def speaker_editor_ui():
        spk: Speaker = Speaker.from_file(spk_file)
        spk.name = name
        spk.gender = gender
-        spk.…
+        spk.describe = desc
 
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
            torch.save(spk, tmp_file)
modules/webui/speaker/speaker_merger.py CHANGED
@@ -38,12 +38,8 @@ def merge_spk(
    tensor_c = spk_to_tensor(spk_c)
    tensor_d = spk_to_tensor(spk_d)
 
-    assert (
-        tensor_a is not None
-        or tensor_b is not None
-        or tensor_c is not None
-        or tensor_d is not None
-    ), "At least one speaker should be selected"
+    if tensor_a is None and tensor_b is None and tensor_c is None and tensor_d is None:
+        raise gr.Error("At least one speaker should be selected")
 
    merge_tensor = torch.zeros_like(
        tensor_a