zhzluke96 committed · Commit 8c22399 · 1 Parent(s): 4554b6b

update

Browse files:
- launch.py +150 -58
- modules/api/Api.py +6 -20
- modules/api/impl/refiner_api.py +6 -1
- modules/api/impl/speaker_api.py +13 -13
- modules/api/impl/tts_api.py +0 -2
- modules/gradio_dcls_fix.py +6 -0
- modules/webui/app.py +12 -4
- modules/webui/js/localization.js +22 -3
- modules/webui/tts_tab.py +1 -1
- webui.py +79 -63
launch.py
CHANGED
@@ -1,109 +1,201 @@
+import os
+import logging
+
+logging.basicConfig(
+    level=os.getenv("LOG_LEVEL", "INFO"),
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+
+from modules.devices import devices
+import argparse
+import uvicorn
+
 import torch
 from modules import config
+from modules.utils import env
 from modules import generate_audio as generate
-
-from functools import lru_cache
-from typing import Callable
-
 from modules.api.Api import APIManager
 
 from modules.api.impl import (
-    …
+    style_api,
     tts_api,
     ssml_api,
     google_api,
     openai_api,
     refiner_api,
+    speaker_api,
+    ping_api,
+    models_api,
 )
 
+logger = logging.getLogger(__name__)
+
 torch._dynamo.config.cache_size_limit = 64
 torch._dynamo.config.suppress_errors = True
 torch.set_float32_matmul_precision("high")
 
 
-def create_api():
-    …
+def create_api(app, no_docs=False, exclude=[]):
+    app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)
 
-    return …
+    ping_api.setup(app_mgr)
+    models_api.setup(app_mgr)
+    style_api.setup(app_mgr)
+    speaker_api.setup(app_mgr)
+    tts_api.setup(app_mgr)
+    ssml_api.setup(app_mgr)
+    google_api.setup(app_mgr)
+    openai_api.setup(app_mgr)
+    refiner_api.setup(app_mgr)
 
+    return app_mgr
 
-def …
-    …
-    def wrapper(*args, **kwargs):
-        if condition(*args, **kwargs):
-            return cached_func(*args, **kwargs)
-        else:
-            return func(*args, **kwargs)
-    …
+def get_and_update_env(*args):
+    val = env.get_env_or_arg(*args)
+    key = args[1]
+    config.runtime_env_vars[key] = val
+    return val
 
-    import argparse
-    import uvicorn
 
-    parser = argparse.ArgumentParser(
-        description="Start the FastAPI server with command line arguments"
-    )
-    parser.add_argument(
-        "--…
-    )
-    parser.add_argument(
-        "--…
-    )
-    parser.add_argument(
-        "--…
-    )
-    parser.add_argument("--compile", action="store_true", help="Enable model compile")
+def setup_model_args(parser: argparse.ArgumentParser):
+    parser.add_argument("--compile", action="store_true", help="Enable model compile")
+    parser.add_argument(
+        "--half",
+        action="store_true",
+        help="Enable half precision for model inference",
+    )
+    parser.add_argument(
+        "--off_tqdm",
+        action="store_true",
+        help="Disable tqdm progress bar",
+    )
+    parser.add_argument(
+        "--device_id",
+        type=str,
+        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
+        default=None,
+    )
+    parser.add_argument(
+        "--use_cpu",
+        nargs="+",
+        help="use CPU as torch device for specified modules",
+        default=[],
+        type=str.lower,
+    )
     parser.add_argument(
         "--lru_size",
         type=int,
         default=64,
         help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
     )
+
+
+def setup_api_args(parser: argparse.ArgumentParser):
+    parser.add_argument("--api_host", type=str, help="Host to run the server on")
+    parser.add_argument("--api_port", type=int, help="Port to run the server on")
+    parser.add_argument(
+        "--reload", action="store_true", help="Enable auto-reload for development"
+    )
     parser.add_argument(
         "--cors_origin",
         type=str,
-        default="*",
         help="Allowed CORS origins. Use '*' to allow all origins.",
     )
+    parser.add_argument(
+        "--no_playground",
+        action="store_true",
+        help="Disable the playground entry",
+    )
+    parser.add_argument(
+        "--no_docs",
+        action="store_true",
+        help="Disable the documentation entry",
+    )
+    # Configure which APIs to skip, e.g. exclude="/v1/speakers/*,/v1/tts/*"
+    parser.add_argument(
+        "--exclude",
+        type=str,
+        help="Exclude the specified API from the server",
+    )
 
-    args = parser.parse_args()
 
-    …
+def process_model_args(args):
+    lru_size = get_and_update_env(args, "lru_size", 64, int)
+    compile = get_and_update_env(args, "compile", False, bool)
+    device_id = get_and_update_env(args, "device_id", None, str)
+    use_cpu = get_and_update_env(args, "use_cpu", [], list)
+    half = get_and_update_env(args, "half", False, bool)
+    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
 
-    def should_cache(*args, **kwargs):
-        spk_seed = kwargs.get("spk_seed", -1)
-        infer_seed = kwargs.get("infer_seed", -1)
-        return spk_seed != -1 and infer_seed != -1
+    generate.setup_lru_cache()
+    devices.reset_device()
+    devices.first_time_calculation()
 
-    …
 
-    api = create_api()
+def process_api_args(args, app):
+    cors_origin = get_and_update_env(args, "cors_origin", "*", str)
+    no_playground = get_and_update_env(args, "no_playground", False, bool)
+    no_docs = get_and_update_env(args, "no_docs", False, bool)
+    exclude = get_and_update_env(args, "exclude", "", str)
+
+    api = create_api(app=app, no_docs=no_docs, exclude=exclude.split(","))
     config.api = api
 
-    if …
-        api.set_cors(allow_origins=[…
+    if cors_origin:
+        api.set_cors(allow_origins=[cors_origin])
+
+    if not no_playground:
+        api.setup_playground()
+
+    if compile:
+        logger.info("Model compile is enabled")
+
+
+app_description = """
+ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
+ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
+
+项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
+
+> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
+> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
+
+> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
+> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
+"""
+app_title = "ChatTTS Forge API"
+app_version = "0.1.0"
+
+if __name__ == "__main__":
+    import dotenv
+    from fastapi import FastAPI
+
+    dotenv.load_dotenv(
+        dotenv_path=os.getenv("ENV_FILE", ".env.api"),
+    )
+
+    parser = argparse.ArgumentParser(
+        description="Start the FastAPI server with command line arguments"
+    )
+    setup_api_args(parser)
+    setup_model_args(parser)
+
+    args = parser.parse_args()
+
+    app = FastAPI(
+        title=app_title,
+        description=app_description,
+        version=app_version,
+        redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
+        docs_url=None if config.runtime_env_vars.no_docs else "/docs",
+    )
+
+    process_model_args(args)
+    process_api_args(args, app)
+
+    host = get_and_update_env(args, "api_host", "0.0.0.0", str)
+    port = get_and_update_env(args, "api_port", 7870, int)
+    reload = get_and_update_env(args, "reload", False, bool)
 
-    uvicorn.run(…
+    uvicorn.run(app, host=host, port=port, reload=reload)
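Note on the removed caching helpers: the deleted block wrapped generate_audio in a conditional LRU cache, now replaced by the single call generate.setup_lru_cache(). From the deleted lines that survive (wrapper, condition, cached_func, should_cache), the old helper was plausibly a decorator along these lines; the decorator name and maxsize below are assumptions, only wrapper and should_cache are verbatim from the old file:

from functools import lru_cache
from typing import Callable


def conditional_cache(condition: Callable):
    # hypothetical reconstruction of the deleted helper; only the inner
    # wrapper and should_cache below appear verbatim in the old launch.py
    def decorator(func):
        cached_func = lru_cache(maxsize=64)(func)  # maxsize guess: the old --lru_size default

        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return wrapper

    return decorator


def should_cache(*args, **kwargs):
    # only deterministic requests (both seeds fixed) are safe to cache
    spk_seed = kwargs.get("spk_seed", -1)
    infer_seed = kwargs.get("infer_seed", -1)
    return spk_seed != -1 and infer_seed != -1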
modules/api/Api.py
CHANGED
@@ -1,4 +1,4 @@
-from fastapi import FastAPI
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 import logging
@@ -24,25 +24,8 @@ def is_excluded(path, exclude_patterns):
 
 
 class APIManager:
-    def __init__(self, no_docs=False, exclude_patterns=[]):
-        self.app = FastAPI(
-            title="ChatTTS Forge API",
-            description="""
-ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
-ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
-
-项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
-
-> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
-> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
-
-> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
-> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
-""",
-            version="0.1.0",
-            redoc_url=None if no_docs else "/redoc",
-            docs_url=None if no_docs else "/docs",
-        )
+    def __init__(self, app: FastAPI, no_docs=False, exclude_patterns=[]):
+        self.app = app
         self.registered_apis = {}
         self.logger = logging.getLogger(__name__)
         self.exclude = exclude_patterns
@@ -57,6 +40,8 @@
         allow_methods: list = ["*"],
         allow_headers: list = ["*"],
     ):
+        # reset middleware stack
+        self.app.middleware_stack = None
         self.app.add_middleware(
             CORSMiddleware,
             allow_origins=allow_origins,
@@ -64,6 +49,7 @@
             allow_methods=allow_methods,
             allow_headers=allow_headers,
         )
+        self.app.build_middleware_stack()
 
     def setup_playground(self):
         app = self.app
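The set_cors() change is the interesting part here: APIManager no longer owns the FastAPI app (webui.py now hands it an app Gradio has already launched), so CORS middleware may be added after Starlette has built its middleware stack, which add_middleware() normally refuses. A minimal sketch of the same trick in isolation, assuming standard Starlette behavior:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware


def add_cors_late(app: FastAPI, origins: list):
    # discard the built stack so add_middleware() is allowed again
    app.middleware_stack = None
    app.add_middleware(CORSMiddleware, allow_origins=origins)
    # rebuild immediately so the next request hits the CORS layer;
    # Starlette would also rebuild lazily if the stack is left as None
    app.middleware_stack = app.build_middleware_stack()


app = FastAPI()
add_cors_late(app, ["*"])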
modules/api/impl/refiner_api.py
CHANGED
@@ -7,6 +7,7 @@ from modules import refiner
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
+from modules.normalization import text_normalize
 
 
 class RefineTextRequest(BaseModel):
@@ -18,6 +19,7 @@ class RefineTextRequest(BaseModel):
     temperature: float = 0.7
     repetition_penalty: float = 1.0
     max_new_token: int = 384
+    normalize: bool = True
 
 
 async def refiner_prompt_post(request: RefineTextRequest):
@@ -26,8 +28,11 @@ async def refiner_prompt_post(request: RefineTextRequest):
     """
 
     try:
+        text = request.text
+        if request.normalize:
+            text = text_normalize(request.text)
         refined_text = refiner.refine_text(
-            text=request.text,
+            text=text,
             prompt=request.prompt,
             seed=request.seed,
             top_P=request.top_P,
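With the new normalize field (default True) the refiner runs text_normalize() before refining, and clients can opt out per request. A sketch of a call, assuming the route from this module is mounted at /v1/prompt/refine and the server runs on launch.py's default port 7870 (neither appears in this hunk):

import requests

resp = requests.post(
    "http://localhost:7870/v1/prompt/refine",  # assumed path, not shown in this diff
    json={
        "text": "raw user text 123",
        "normalize": False,  # skip text_normalize() and refine the raw text
        "seed": 42,
    },
)
print(resp.json())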
modules/api/impl/speaker_api.py
CHANGED
@@ -35,10 +35,14 @@ def setup(app: APIManager):
 
     @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
     async def list_speakers():
-        return …
+        return api_utils.success_response(
+            [spk.to_json() for spk in speaker_mgr.list_speakers()]
+        )
+
+    @app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
+    async def refresh_speakers():
+        speaker_mgr.refresh_speakers()
+        return api_utils.success_response(None)
 
     @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
     async def update_speakers(request: SpeakersUpdate):
@@ -59,7 +63,8 @@
         # number array => Tensor
         speaker.emb = torch.tensor(spk["tensor"])
         speaker_mgr.save_all()
-        return …
+
+        return api_utils.success_response(None)
 
     @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
     async def create_speaker(request: CreateSpeaker):
@@ -88,12 +93,7 @@
             raise HTTPException(
                 status_code=400, detail="Missing tensor or seed in request"
             )
-        return …
-
-    @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
-    async def refresh_speakers():
-        speaker_mgr.refresh_speakers()
-        return {"message": "ok"}
+        return api_utils.success_response(speaker.to_json())
 
     @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
     async def update_speaker(request: UpdateSpeaker):
@@ -113,11 +113,11 @@
         # number array => Tensor
         speaker.emb = torch.tensor(request.tensor)
         speaker_mgr.update_speaker(speaker)
-        return …
+        return api_utils.success_response(None)
 
     @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
     async def speaker_detail(request: SpeakerDetail):
         speaker = speaker_mgr.get_speaker_by_id(request.id)
         if speaker is None:
             raise HTTPException(status_code=404, detail="Speaker not found")
-        return …
+        return api_utils.success_response(speaker.to_json(with_emb=request.with_emb))
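All speaker endpoints now return through api_utils.success_response() instead of ad-hoc dicts like {"message": "ok"}, so every route yields the same BaseResponse envelope. api_utils itself is outside this diff; a sketch of what the envelope presumably looks like, with the field names as assumptions:

from typing import Any

from pydantic import BaseModel


class BaseResponse(BaseModel):
    # assumed shape; the real model lives in modules/api/utils.py
    message: str = "ok"
    data: Any = None


def success_response(data: Any) -> BaseResponse:
    return BaseResponse(message="ok", data=data)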
modules/api/impl/tts_api.py
CHANGED
@@ -9,8 +9,6 @@ from fastapi.responses import FileResponse
 
 from modules.normalization import text_normalize
 
-from modules import generate_audio as generate
-
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 from modules.synthesize_audio import synthesize_audio
modules/gradio_dcls_fix.py
ADDED
@@ -0,0 +1,6 @@
+def dcls_patch():
+    from gradio import data_classes
+
+    data_classes.PredictBody.__get_pydantic_json_schema__ = lambda x, y: {
+        "type": "object",
+    }
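This new module stubs out the pydantic v2 schema hook on Gradio's PredictBody so that OpenAPI generation reports a bare {"type": "object"} for it instead of choking, presumably needed once the Gradio app and the documented Forge routes share one FastAPI instance (the exact failure mode is an inference; the commit does not state it). It has to run before the app is built, which is why webui.py calls it at import time:

from modules.gradio_dcls_fix import dcls_patch

dcls_patch()  # apply before creating/mounting the Gradio app so /docs can render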
modules/webui/app.py
CHANGED
@@ -46,11 +46,19 @@ def create_app_footer():
 
     config.versions.gradio_version = gradio_version
 
+    footer_items = ["🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)"]
+    footer_items.append(
+        f"version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit})"
+    )
+    footer_items.append(f"branch: `{git_branch}`")
+    footer_items.append(f"python: `{python_version}`")
+    footer_items.append(f"torch: `{torch_version}`")
+
+    if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
+        footer_items.append(f"[API](/docs)")
+
     gr.Markdown(
-        f"""
-🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
-version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
-""",
+        " | ".join(footer_items),
         elem_classes=["no-translate"],
     )
 
modules/webui/js/localization.js
CHANGED
@@ -163,6 +163,23 @@ function localizeWholePage() {
   }
 }
 
+/**
+ *
+ * @param {HTMLElement} node
+ */
+function isNeedTranslate(node) {
+  if (!node) return false;
+  if (!(node instanceof HTMLElement)) return true;
+  while (node.parentElement !== document.body) {
+    if (node.classList.contains("no-translate")) {
+      return false;
+    }
+    node = node.parentElement;
+    if (!node) break;
+  }
+  return true;
+}
+
 document.addEventListener("DOMContentLoaded", function () {
   if (!hasLocalization()) {
     return;
@@ -170,9 +187,11 @@
 
   onUiUpdate(function (m) {
     m.forEach(function (mutation) {
-      mutation.addedNodes…
+      Array.from(mutation.addedNodes)
+        .filter(isNeedTranslate)
+        .forEach(function (node) {
+          processNode(node);
+        });
     });
   });
 
modules/webui/tts_tab.py
CHANGED
@@ -96,7 +96,7 @@ def create_tts_interface():
     )
 
     gr.Markdown("📝Speaker info")
-    infos = gr.Markdown("empty")
+    infos = gr.Markdown("empty", elem_classes=["no-translate"])
 
     spk_file_upload.change(
         fn=load_spk_info,
webui.py
CHANGED
@@ -1,27 +1,30 @@
 import os
 import logging
 
+logging.basicConfig(
+    level=os.getenv("LOG_LEVEL", "INFO"),
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
-…
 
-from …
-…
+from launch import (
+    get_and_update_env,
+    setup_api_args,
+    setup_model_args,
+    process_api_args,
+    process_model_args,
+    app_description,
+    app_title,
+    app_version,
+)
 from modules.webui import webui_config
 from modules.webui.app import webui_init, create_interface
-
-from modules import …
+import argparse
+from modules.gradio_dcls_fix import dcls_patch
 
-…
-import argparse
-import dotenv
+dcls_patch()
 
-dotenv.load_dotenv(
-    dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
-)
 
-…
+def setup_webui_args(parser: argparse.ArgumentParser):
     parser.add_argument("--server_name", type=str, help="server name")
     parser.add_argument("--server_port", type=int, help="server port")
     parser.add_argument(
@@ -29,16 +32,6 @@
     )
     parser.add_argument("--debug", action="store_true", help="enable debug mode")
     parser.add_argument("--auth", type=str, help="username:password for authentication")
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        help="Enable half precision for model inference",
-    )
-    parser.add_argument(
-        "--off_tqdm",
-        action="store_true",
-        help="Disable tqdm progress bar",
-    )
     parser.add_argument(
         "--tts_max_len",
         type=int,
@@ -54,58 +47,39 @@
         type=int,
         help="Max batch size for TTS",
     )
-    parser.add_argument(
-        "--lru_size",
-        type=int,
-        default=64,
-        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
-    )
-    parser.add_argument(
-        "--device_id",
-        type=str,
-        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
-        default=None,
-    )
-    parser.add_argument(
-        "--use_cpu",
-        nargs="+",
-        help="use CPU as torch device for specified modules",
-        default=[],
-        type=str.lower,
-    )
-    parser.add_argument("--compile", action="store_true", help="Enable model compile")
     # webui_Experimental
     parser.add_argument(
         "--webui_experimental",
        action="store_true",
        help="Enable webui_experimental features",
     )
-
     parser.add_argument(
         "--language",
         type=str,
         help="Set the default language for the webui",
     )
+    parser.add_argument(
+        "--api",
+        action="store_true",
+        help="use api=True to launch the API together with the webui (run launch.py for only API server)",
+    )
 
-
-def get_and_update_env(*args):
-    val = env.get_env_or_arg(*args)
-    key = args[1]
-    config.runtime_env_vars[key] = val
-    return val
 
+def process_webui_args(args):
     server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
     server_port = get_and_update_env(args, "server_port", 7860, int)
     share = get_and_update_env(args, "share", False, bool)
     debug = get_and_update_env(args, "debug", False, bool)
     auth = get_and_update_env(args, "auth", None, str)
-    half = get_and_update_env(args, "half", False, bool)
-    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
-    lru_size = get_and_update_env(args, "lru_size", 64, int)
-    device_id = get_and_update_env(args, "device_id", None, str)
-    use_cpu = get_and_update_env(args, "use_cpu", [], list)
-    compile = get_and_update_env(args, "compile", False, bool)
     language = get_and_update_env(args, "language", "zh-CN", str)
+    api = get_and_update_env(args, "api", "zh-CN", str)
+
+    webui_config.experimental = get_and_update_env(
+        args, "webui_experimental", False, bool
+    )
+    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
+    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
+    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
 
     webui_config.experimental = get_and_update_env(
         args, "webui_experimental", False, bool
@@ -120,15 +94,57 @@
     if auth:
         auth = tuple(auth.split(":"))
 
-
-    devices.reset_device()
-    devices.first_time_calculation()
-
-    demo.queue().launch(
+    app, local_url, share_url = demo.queue().launch(
         server_name=server_name,
         server_port=server_port,
         share=share,
         debug=debug,
         auth=auth,
         show_api=False,
+        prevent_thread_lock=True,
+        app_kwargs={
+            "title": app_title,
+            "description": app_description,
+            "version": app_version,
+            # "redoc_url": (
+            #     None
+            #     if api is False
+            #     else None if config.runtime_env_vars.no_docs else "/redoc"
+            # ),
+            # "docs_url": (
+            #     None
+            #     if api is False
+            #     else None if config.runtime_env_vars.no_docs else "/docs"
+            # ),
+            "docs_url": "/docs",
+        },
     )
+    # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
+    # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
+    # running web ui and do whatever the attacker wants, including installing an extension and
+    # running its code. We disable this here. Suggested by RyotaK.
+    app.user_middleware = [
+        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
+    ]
+
+    if api:
+        process_api_args(args, app)
+
+
+if __name__ == "__main__":
+    import dotenv
+
+    dotenv.load_dotenv(
+        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
+    )
+
+    parser = argparse.ArgumentParser(description="Gradio App")
+
+    setup_webui_args(parser)
+    setup_model_args(parser)
+    setup_api_args(parser)
+
+    args = parser.parse_args()
+
+    process_model_args(args)
+    process_webui_args(args)