Commit: Init hf space integration

Note: this view is limited to 50 files because the commit contains too many changes.

- .gitignore +2 -0
- app.py +317 -0
- fish_speech/callbacks/__init__.py +3 -0
- fish_speech/callbacks/grad_norm.py +113 -0
- fish_speech/configs/base.yaml +86 -0
- fish_speech/configs/model/dual_ar_2_codebook_large.yaml +9 -0
- fish_speech/configs/model/dual_ar_2_codebook_medium.yaml +9 -0
- fish_speech/configs/model/dual_ar_2_codebook_small.yaml +13 -0
- fish_speech/configs/model/naive_2_codebook_small.yaml +12 -0
- fish_speech/configs/text2semantic_finetune.yaml +79 -0
- fish_speech/configs/text2semantic_finetune_lora.yaml +13 -0
- fish_speech/configs/text2semantic_pretrain.yaml +74 -0
- fish_speech/configs/text2semantic_sft.yaml +87 -0
- fish_speech/configs/vqgan_finetune.yaml +135 -0
- fish_speech/configs/vqgan_pretrain.yaml +139 -0
- fish_speech/datasets/protos/text-data.proto +24 -0
- fish_speech/datasets/protos/text_data_pb2.py +33 -0
- fish_speech/datasets/protos/text_data_stream.py +36 -0
- fish_speech/datasets/text.py +661 -0
- fish_speech/datasets/vqgan.py +145 -0
- fish_speech/models/text2semantic/__init__.py +3 -0
- fish_speech/models/text2semantic/lit_module.py +344 -0
- fish_speech/models/text2semantic/llama.py +595 -0
- fish_speech/models/vqgan/__init__.py +3 -0
- fish_speech/models/vqgan/lit_module.py +442 -0
- fish_speech/models/vqgan/modules/discriminator.py +44 -0
- fish_speech/models/vqgan/modules/firefly.py +538 -0
- fish_speech/models/vqgan/modules/fsq.py +139 -0
- fish_speech/models/vqgan/modules/reference.py +113 -0
- fish_speech/models/vqgan/modules/wavenet.py +225 -0
- fish_speech/models/vqgan/spectrogram.py +122 -0
- fish_speech/models/vqgan/utils.py +94 -0
- fish_speech/scheduler.py +22 -0
- fish_speech/text/__init__.py +3 -0
- fish_speech/text/clean.py +73 -0
- fish_speech/train.py +135 -0
- fish_speech/utils/__init__.py +21 -0
- fish_speech/utils/braceexpand.py +217 -0
- fish_speech/utils/file.py +119 -0
- fish_speech/utils/instantiators.py +50 -0
- fish_speech/utils/logger.py +55 -0
- fish_speech/utils/logging_utils.py +48 -0
- fish_speech/utils/rich_utils.py +96 -0
- fish_speech/utils/utils.py +114 -0
- packages.txt +10 -0
- pyrightconfig.json +6 -0
- requirements.txt +24 -0
- setup.sh +18 -0
- tools/extract_model.py +21 -0
- tools/llama/build_dataset.py +165 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__
checkpoints

app.py
ADDED
@@ -0,0 +1,317 @@
import subprocess as sp
sp.check_call("setup.sh", shell=True)

import html
import os
from argparse import ArgumentParser
from io import BytesIO
from pathlib import Path

import gradio as gr
import librosa
import spaces
import torch
from loguru import logger
from torchaudio import functional as AF
from transformers import AutoTokenizer

from tools.llama.generate import generate_long
from tools.llama.generate import load_model as load_llama_model
from tools.vqgan.inference import load_model as load_vqgan_model

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"


HEADER_MD = """# Fish Speech

A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.

You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.

Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.
相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.

We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
"""

TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""


def build_html_error_message(error):
    return f"""
    <div style="color: red; font-weight: bold;">
        {html.escape(error)}
    </div>
    """


@spaces.GPU
def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker=None,
):
    if len(reference_text) > 100:
        return None, "Ref text is too long, please keep it under 100 characters."

    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        return None, "Text is too long, please keep it under 1000 characters."

    # Parse reference audio aka prompt
    if enable_reference_audio and reference_audio is not None:
        # reference_audio_sr, reference_audio_content = reference_audio
        reference_audio_content, _ = librosa.load(
            reference_audio, sr=vqgan_model.sampling_rate, mono=True
        )
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]

        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]

    # LLAMA Inference
    result = generate_long(
        model=llama_model,
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        decode_one_token=decode_one_token,
        max_new_tokens=max_new_tokens,
        text=text,
        top_k=int(top_k) if top_k > 0 else None,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=args.max_length,
        speaker=speaker if speaker else None,
        prompt_tokens=prompt_tokens if enable_reference_audio else None,
        prompt_text=reference_text if enable_reference_audio else None,
    )

    codes = next(result)

    # VQGAN Inference
    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
    fake_audios = vqgan_model.decode(
        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
    )[0, 0]

    fake_audios = fake_audios.float().cpu().numpy()

    return (vqgan_model.sampling_rate, fake_audios), None


def build_app():
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Input Text / 输入文本",
                    placeholder=TEXTBOX_PLACEHOLDER,
                    lines=15,
                )

                with gr.Row():
                    with gr.Tab(label="Advanced Config / 高级参数"):
                        chunk_length = gr.Slider(
                            label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
                            minimum=0,
                            maximum=100,
                            value=30,
                            step=8,
                        )

                        max_new_tokens = gr.Slider(
                            label="Maximum tokens per batch, 0 means no limit / 每批最大令牌数,0 表示无限制",
                            minimum=128,
                            maximum=512,
                            value=512,  # 0 means no limit
                            step=8,
                        )

                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=5, value=0, step=1
                        )

                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
                        )

                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            minimum=0,
                            maximum=2,
                            value=1.5,
                            step=0.01,
                        )

                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0,
                            maximum=2,
                            value=0.7,
                            step=0.01,
                        )

                        # speaker = gr.Textbox(
                        #     label="Speaker / 说话人",
                        #     placeholder="Type name of the speaker / 输入说话人的名称",
                        #     lines=1,
                        # )

                    with gr.Tab(label="Reference Audio / 参考音频"):
                        gr.Markdown(
                            "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
                        )

                        enable_reference_audio = gr.Checkbox(
                            label="Enable Reference Audio / 启用参考音频",
                        )
                        reference_audio = gr.Audio(
                            label="Reference Audio / 参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="Reference Text / 参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(label="Error Message / 错误信息")
                with gr.Row():
                    audio = gr.Audio(label="Generated Audio / 音频", type="numpy")

        with gr.Row():
            with gr.Column(scale=3):
                generate = gr.Button(
                    value="\U0001F3A7 Generate / 合成", variant="primary"
                )

        # # Submit
        generate.click(
            inference,
            [
                text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_k,
                top_p,
                repetition_penalty,
                temperature,
                # speaker,
            ],
            [audio, error],
        )

    return app


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/text2semantic-medium-v1-2k.pth",
    )
    parser.add_argument(
        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
    )
    parser.add_argument(
        "--vqgan-checkpoint-path",
        type=Path,
        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
    )
    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--max-length", type=int, default=2048)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=1024)

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_model, decode_one_token = load_llama_model(
        config_name=args.llama_config_name,
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        max_length=args.max_length,
        compile=args.compile,
    )
    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info("Llama model loaded, loading VQ-GAN model...")

    vqgan_model = load_vqgan_model(
        config_name=args.vqgan_config_name,
        checkpoint_path=args.vqgan_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    inference(
        text="Hello, world!",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=0,
        chunk_length=0,
        top_k=0,  # 0 means no limit
        top_p=0.7,
        repetition_penalty=1.5,
        temperature=0.7,
        speaker=None,
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=False)

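Example (not part of the commit): the entry point above runs setup.sh, loads the Llama and VQ-GAN checkpoints, does one warm-up inference() call, then serves the Gradio UI. A minimal sketch of launching it outside the Space, assuming the two default checkpoints from parse_args() are already present under checkpoints/ (setup.sh, which normally fetches them, is not shown in this view):

# Hypothetical local launch; flags mirror parse_args() above.
import subprocess

subprocess.check_call(
    [
        "python", "app.py",
        "--llama-checkpoint-path", "checkpoints/text2semantic-medium-v1-2k.pth",
        "--vqgan-checkpoint-path", "checkpoints/vq-gan-group-fsq-2x1024.pth",
        "--max-gradio-length", "1024",
    ]
)
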
fish_speech/callbacks/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .grad_norm import GradNormMonitor

__all__ = ["GradNormMonitor"]

fish_speech/callbacks/grad_norm.py
ADDED
@@ -0,0 +1,113 @@
from typing import Optional, Union

import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor, nn
from torch.utils._foreach_utils import (
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def grad_norm(
    parameters: Union[Tensor, list[Tensor]],
    norm_type: float = 2.0,
) -> float:
    """
    Returns the norm of the gradients of the given parameters.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float): type of the used p-norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """  # noqa: E501

    if isinstance(parameters, Tensor):
        parameters = [parameters]

    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return None

    first_device = grads[0].device
    grouped_grads: dict[
        tuple[torch.device, torch.dtype], list[list[Tensor]]
    ] = _group_tensors_by_device_and_dtype(
        [[g.detach() for g in grads]]
    )  # type: ignore[assignment]

    norms = []
    for (device, _), ([grads], _) in grouped_grads.items():
        if _has_foreach_support(grads, device=device):
            norms.extend(torch._foreach_norm(grads, norm_type))
        else:
            norms.extend([torch.norm(g, norm_type) for g in grads])

    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)


class GradNormMonitor(Callback):
    """
    Callback that computes the gradient norm of the model parameters.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch".
        """
        super().__init__()

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Computes the gradient norm of the model parameters and logs it to the logger.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """

        lightning_model = model

        if self.sub_module is None:
            return self.log_sub_module_grad_norm(lightning_model, model, "")

        sub_modules = self.sub_module
        if isinstance(sub_modules, str):
            sub_modules = [sub_modules]

        for sub_module in sub_modules:
            self.log_sub_module_grad_norm(
                lightning_model, getattr(model, sub_module), f"/{sub_module}"
            )

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
        if grad_norm_val is None:
            return

        on_step = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            grad_norm_val,
            on_step=on_step,
            on_epoch=not on_step,
        )

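Example (not part of the commit): a minimal sketch of attaching this callback directly to a Lightning Trainer, mirroring what base.yaml below does through Hydra. The model and datamodule are placeholders; the monitored attributes must exist on the LightningModule.

# Log the 2-norm of encoder/decoder gradients separately, every step.
import lightning.pytorch as pl
from fish_speech.callbacks import GradNormMonitor

trainer = pl.Trainer(
    max_steps=1000,
    callbacks=[
        GradNormMonitor(
            norm_type=2.0,
            logging_interval="step",
            sub_module=["encoder", "decoder"],  # model must expose .encoder / .decoder
        ),
    ],
)
# trainer.fit(model, datamodule=data)
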
fish_speech/configs/base.yaml
ADDED
@@ -0,0 +1,86 @@
# Base configuration for training a model
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints

hydra:
  run:
    dir: ${paths.run_dir}

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy

  precision: bf16-mixed

  # disable validation by epoch end
  check_val_every_n_epoch: null
  val_check_interval: 5000
  max_steps: 100_000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
    save_top_k: 5 # save 5 latest checkpoints
    monitor: step # use step to monitor checkpoints
    mode: max # save the latest checkpoint with the highest global_step
    every_n_epochs: null # don't save checkpoints by epoch end
    every_n_train_steps: 5000 # save checkpoints every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2 # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: "" # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: False
  #   id: null # pass correct id to resume experiment!
  #   anonymous: null # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: False # upload lightning ckpts
  #   prefix: "" # a string to put at the beginning of metric keys
  #   # entity: "" # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
train: true
test: false

fish_speech/configs/model/dual_ar_2_codebook_large.yaml
ADDED
@@ -0,0 +1,9 @@
defaults:
  - dual_ar_2_codebook_small
  - _self_

config:
  n_layer: 30
  n_fast_layer: 6
  n_head: 24
  dim: 1536

fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
ADDED
@@ -0,0 +1,9 @@
defaults:
  - dual_ar_2_codebook_small
  - _self_

config:
  n_layer: 24
  n_fast_layer: 6
  n_head: 16
  dim: 1024

fish_speech/configs/model/dual_ar_2_codebook_small.yaml
ADDED
@@ -0,0 +1,13 @@
_target_: fish_speech.models.text2semantic.llama.DualARTransformer
config:
  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
  max_seq_len: ${max_length}
  vocab_size: 264 # pad 262 to 8x
  n_layer: 12
  n_fast_layer: 4
  n_head: 12
  dim: 768
  rope_base: 10000
  norm_eps: 1e-5
  num_codebooks: 2 # input/output codebook size
  codebook_size: 1032 # codebook size 1024 + 2 special tokens

fish_speech/configs/model/naive_2_codebook_small.yaml
ADDED
@@ -0,0 +1,12 @@
_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
config:
  _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
  max_seq_len: ${max_length}
  vocab_size: 36408
  n_layer: 12
  n_head: 12
  dim: 768
  rope_base: 10000
  norm_eps: 1e-5
  num_codebooks: 2 # input/output codebook size
  codebook_size: 1032 # codebook size 1024 + 2 special tokens

fish_speech/configs/text2semantic_finetune.yaml
ADDED
@@ -0,0 +1,79 @@
defaults:
  - base
  - model@model.model: dual_ar_2_codebook_small
  - _self_

project: text2semantic_finetune_dual_ar
max_length: 2048
ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
resume_weights_only: true

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 100

# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/fish-speech-1

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model: {}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-5
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}

# Callbacks
callbacks:
  model_checkpoint:
    every_n_train_steps: 100

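Example (not part of the commit): a sketch of how a `_partial_: true` node such as the optimizer above resolves at runtime with Hydra. The values are copied from the config; the stand-in model and the surrounding setup are hypothetical.

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

optimizer_cfg = OmegaConf.create({
    "_target_": "torch.optim.AdamW",
    "_partial_": True,
    "lr": 1e-5,
    "weight_decay": 0,
    "betas": [0.9, 0.95],
    "eps": 1e-5,
})
make_optimizer = instantiate(optimizer_cfg)    # returns functools.partial(AdamW, ...)
model = torch.nn.Linear(8, 8)                  # stand-in for the real model
optimizer = make_optimizer(model.parameters()) # parameters supplied by the LightningModule
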
fish_speech/configs/text2semantic_finetune_lora.yaml
ADDED
@@ -0,0 +1,13 @@
defaults:
  - text2semantic_finetune
  - _self_

project: text2semantic_finetune_dual_ar_lora

# Model Configuration
model:
  save_lora_only: true
  lora_config:
    _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
    r: 8
    lora_alpha: 16

fish_speech/configs/text2semantic_pretrain.yaml
ADDED
@@ -0,0 +1,74 @@
defaults:
  - base
  - model@model.model: dual_ar_2_codebook_small
  - _self_

project: text2semantic_pretrain_dual_ar_debug
max_length: 2048

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1_000_000
  precision: bf16-true
  limit_val_batches: 10

# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/fish-speech-1

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos/train
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  interactive_prob: 0.5

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos/test
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  interactive_prob: 0.5

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model: {}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 3e-4
    weight_decay: 0.01
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 2000
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.1

fish_speech/configs/text2semantic_sft.yaml
ADDED
@@ -0,0 +1,87 @@
defaults:
  - base
  - model@model.model: dual_ar_8_codebook_small
  - _self_

project: text2semantic_sft_medium_dual_ar
max_length: 4096
ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
resume_weights_only: true

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 10_000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 500

# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/speech-lm-v1

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  use_data_server: false
  proto_files:
    - data/protos/sft/train_Genshin.protos
    - data/protos/sft/sft.protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  phones_prob: 0.5
  interactive_prob: 0.5

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  use_data_server: false
  proto_files:
    - data/protos/sft/val_Genshin.protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  phones_prob: 0.5
  interactive_prob: 0.5

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model: {}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 4e-5
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_checkpoint:
    every_n_train_steps: 1000
    save_top_k: 10

fish_speech/configs/vqgan_finetune.yaml
ADDED
@@ -0,0 +1,135 @@
defaults:
  - base
  - _self_

project: vq-gan-finetune
ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
resume_weights_only: true

# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 100_000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true

sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048
freeze_encoder: true

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/filelist.train.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  slice_frames: 512

val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/filelist.val.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 16
  val_batch_size: 16

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null # You may download the pretrained vocoder and set the path here

  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 4e-5
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator

fish_speech/configs/vqgan_pretrain.yaml
ADDED
@@ -0,0 +1,139 @@
defaults:
  - base
  - _self_

project: vq-gan-pretrain

# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 1_000_000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true

sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048

# Dataset Configuration
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512

val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null # You may download the pretrained vocoder and set the path here

  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator

fish_speech/datasets/protos/text-data.proto
ADDED
@@ -0,0 +1,24 @@
syntax = "proto3";

package text_data;

message Semantics {
  repeated uint32 values = 1;
}

message Sentence {
  repeated string texts = 1;
  repeated Semantics semantics = 3;
}

message TextData {
  string source = 1;
  string name = 2;
  repeated Sentence sentences = 4;
}

message SampledData {
  string source = 1;
  string name = 2;
  repeated Sentence samples = 3;
}

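Example (not part of the commit): text_data_pb2.py below is the module generated from this schema. A sketch of regenerating it, assuming the grpcio-tools package is available (the repo imports grpc elsewhere, but the tools package and this exact invocation are assumptions, not something defined in the commit):

from grpc_tools import protoc

protoc.main([
    "protoc",
    "-Ifish_speech/datasets/protos",
    "--python_out=fish_speech/datasets/protos",
    "text-data.proto",
])
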
fish_speech/datasets/protos/text_data_pb2.py
ADDED
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _globals["_SEMANTICS"]._serialized_start = 30
    _globals["_SEMANTICS"]._serialized_end = 57
    _globals["_SENTENCE"]._serialized_start = 59
    _globals["_SENTENCE"]._serialized_end = 125
    _globals["_TEXTDATA"]._serialized_start = 127
    _globals["_TEXTDATA"]._serialized_end = 207
    _globals["_SAMPLEDDATA"]._serialized_start = 209
    _globals["_SAMPLEDDATA"]._serialized_end = 290
# @@protoc_insertion_point(module_scope)

fish_speech/datasets/protos/text_data_stream.py
ADDED
@@ -0,0 +1,36 @@
import struct

from .text_data_pb2 import TextData


def read_pb_stream(f):
    while True:
        buf = f.read(4)
        if len(buf) == 0:
            break
        size = struct.unpack("I", buf)[0]
        buf = f.read(size)
        text_data = TextData()
        text_data.ParseFromString(buf)
        yield text_data


def write_pb_stream(f, text_data):
    buf = text_data.SerializeToString()
    f.write(struct.pack("I", len(buf)))
    f.write(buf)


def pack_pb_stream(text_data):
    buf = text_data.SerializeToString()
    return struct.pack("I", len(buf)) + buf


def split_pb_stream(f):
    while True:
        head = f.read(4)
        if len(head) == 0:
            break
        size = struct.unpack("I", head)[0]
        buf = f.read(size)
        yield head + buf

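Example (not part of the commit): a minimal round trip through the helpers above. Field names come from text-data.proto; the file path and values are arbitrary.

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

record = TextData(
    source="demo",
    name="speaker_0",
    sentences=[
        Sentence(texts=["Hello, world!"], semantics=[Semantics(values=[1, 2, 3])]),
    ],
)

with open("demo.protos", "wb") as f:
    write_pb_stream(f, record)           # 4-byte length prefix, then the serialized message

with open("demo.protos", "rb") as f:
    for text_data in read_pb_stream(f):  # yields TextData messages until EOF
        print(text_data.name, len(text_data.sentences))
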
fish_speech/datasets/text.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import grpc
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

CODEBOOK_PAD_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly

    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # encode
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Random choice self.max_length
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )
            # Pad dims
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)

            tokens = torch.concat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )
            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
        causual: bool = True,
        use_negative_samples: bool = False,
        num_codebooks: Optional[int] = None,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causual: use causual sampling when using local data, disable will lead to random sampling
            use_negative_samples: generate negative samples
            num_codebooks: number of codebooks, if None, it will be automatically detected
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causual = causual
        self.use_negative_samples = use_negative_samples
        self.num_codebooks = num_codebooks

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causual:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        # Random sample based on speaker using a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify bos token
        assert tokens[0, 0] == self.tokenizer.bos_token_id

        data = {"tokens": tokens, "labels": labels}

        if self.use_negative_samples:
            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
            data.update(negative_samples)

        return data

    def generate_negative_samples(self, all_tokens, all_labels):
        new_tokens, new_labels = [], []

        for tokens, labels in zip(all_tokens, all_labels):
            # If all codebooks are not -100, we find where it starts
            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
            assert (labels[1:, start:] != -100).all()  # This shouldn't happen

            mode = random.choice(["repeat", "lost", "noise"])
            begin = random.randint(start, labels.size(1) - 1)
            end = random.randint(begin, labels.size(1) - 1)

            if mode == "repeat":
                tokens = torch.cat(
                    [
                        tokens[:, :begin],
                        tokens[:, begin:end],
                        tokens[:, begin:end],
                        tokens[:, end:],
                    ],
                    dim=1,
                )
                labels = torch.cat(
                    [
                        labels[:, :begin],
                        labels[:, begin:end],
                        labels[:, begin:end],
                        labels[:, end:],
                    ],
                    dim=1,
                )
            elif mode == "lost":
                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
            elif mode == "noise":
                middle_tokens, middle_labels = (
                    tokens[:, begin:end],
                    labels[:, begin:end],
                )
                random_order0 = torch.randperm(middle_tokens.size(1))
                random_order1 = torch.randperm(middle_tokens.size(1))
                middle_tokens = middle_tokens[:, random_order0]
                middle_labels = middle_labels[:, random_order1]
                tokens = torch.cat(
                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
                )
                labels = torch.cat(
                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
                )

            new_tokens.append(tokens)
            new_labels.append(labels)

        tokens = torch.cat(new_tokens, dim=1)
        labels = torch.cat(new_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        return {"negative_tokens": tokens, "negative_labels": labels}

    def pack_sentences(
        self,
        sentences: list[str],
        semantics=list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
        final_text = final_text + "<|im_start|>assistant<|im_sep|>"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(
                ["<|im_end|>", "<|end_of_sequence|>"]
            )
        )

        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
            for _ in range(num_codebooks)
        ]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, : (prompt_length + bos_bias)] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        if "negative_tokens" in examples:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    from tqdm import tqdm

    ds = AutoAugTextDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        use_negative_samples=False,
    )

    for i in ds:
        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        # i["labels"][0][i["labels"][0] == -100] = 0
        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
        break
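
The __main__ block above only exercises AutoAugTextDataset on its own. As a quick orientation, here is a minimal sketch of how the pieces in this file could be wired together; the proto directories, the 80/20 mixing weights, and the batch size are hypothetical, and it assumes the fishaudio/fish-speech-1 tokenizer used above is available.

from transformers import AutoTokenizer

from fish_speech.datasets.text import (
    AutoAugTextDataset,
    InterleaveDataset,
    TextDataModule,
)

tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")

# Two hypothetical proto directories, interleaved 80/20.
ds_a = AutoAugTextDataset(["data/protos/set_a"], tokenizer=tokenizer)
ds_b = AutoAugTextDataset(["data/protos/set_b"], tokenizer=tokenizer)
train_ds = InterleaveDataset([ds_a, ds_b], probabilities=[0.8, 0.2])

datamodule = TextDataModule(
    train_dataset=train_ds,
    val_dataset=ds_a,
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
)

# The collator packs each batch into "inputs", "attention_masks", and "labels".
batch = next(iter(datamodule.train_dataloader()))
print(batch["inputs"].shape)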
fish_speech/datasets/vqgan.py
ADDED
@@ -0,0 +1,145 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)


class VQGANDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        self.files = [
            root / line.strip()
            for line in filelist.read_text().splitlines()
            if line.strip()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Slice audio and features
        if (
            self.slice_frames is not None
            and audio.shape[0] > self.slice_frames * self.hop_length
        ):
            start = np.random.randint(
                0, audio.shape[0] - self.slice_frames * self.hop_length
            )
            audio = audio[start : start + self.slice_frames * self.hop_length]

        if len(audio) == 0:
            return None

        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None


@dataclass
class VQGANCollator:
    def __call__(self, batch):
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        # Rounds up to nearest multiple of 2 (audio_lengths)
        audios = []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }


class VQGANDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["features"].shape)
        print(batch["audio_lengths"])
        print(batch["feature_lengths"])
        break
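
Note that the __main__ block above also prints batch["features"] and batch["feature_lengths"], but VQGANCollator only returns "audios" and "audio_lengths". A minimal sketch restricted to the keys the collator actually produces, using a hypothetical filelist path:

from torch.utils.data import DataLoader

from fish_speech.datasets.vqgan import VQGANCollator, VQGANDataset

# Hypothetical filelist: one audio path per line, relative to the list file.
dataset = VQGANDataset("data/my_filelist.txt", sample_rate=32000, hop_length=640)
loader = DataLoader(dataset, batch_size=4, collate_fn=VQGANCollator())

for batch in loader:
    # Only these two keys are emitted; features are computed later in the model.
    print(batch["audios"].shape)   # (batch, max_audio_len)
    print(batch["audio_lengths"])  # (batch,)
    break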
fish_speech/models/text2semantic/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .lit_module import TextToSemantic

__all__ = ["TextToSemantic"]
fish_speech/models/text2semantic/lit_module.py
ADDED
@@ -0,0 +1,344 @@
from dataclasses import dataclass
from typing import Any, Optional

import lightning as L
import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
        save_lora_only: bool = False,
        use_dpo: bool = False,
        dpo_beta: float = 0.2,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config
        self.save_lora_only = save_lora_only
        self.use_dpo = use_dpo  # We don't support reference model yet
        self.dpo_beta = dpo_beta

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        # Replace the embedding layer with a LoRA layer
        self.model.embeddings = lora.Embedding(
            num_embeddings=self.model.embeddings.num_embeddings,
            embedding_dim=self.model.embeddings.embedding_dim,
            padding_idx=self.model.embeddings.padding_idx,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.lora_alpha,
        )

        # Replace output layer with a LoRA layer
        linears = [(self.model, "output")]

        # Replace all linear layers with LoRA layers
        for layer in self.model.layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

        if hasattr(self.model, "fast_layers"):
            # Dual-AR model
            linears.extend([(self.model, "fast_output")])

            for layer in self.model.fast_layers:
                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
                linears.extend(
                    [
                        (layer.feed_forward, "w1"),
                        (layer.feed_forward, "w2"),
                        (layer.feed_forward, "w3"),
                    ]
                )

        for module, layer in linears:
            updated_linear = lora.Linear(
                in_features=getattr(module, layer).in_features,
                out_features=getattr(module, layer).out_features,
                bias=getattr(module, layer).bias,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            setattr(module, layer, updated_linear)

        # Mark only the LoRA layers as trainable
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        if self.lora_config is None or self.save_lora_only is False:
            return

        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        if self.use_dpo:
            # First half is positive, second half is negative
            token_logits, negative_token_logits = token_logits.chunk(2)
            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
            labels, negative_labels = labels.chunk(2)

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        # If we use dpo
        if self.use_dpo:
            negative_codebook_labels = negative_labels[
                :, 1 : 1 + self.model.config.num_codebooks
            ].mT

            positive_codebook_logps = self.get_batch_logps(
                codebook_logits, codebook_labels
            )
            negative_codebook_logps = self.get_batch_logps(
                negative_codebook_logits, negative_codebook_labels
            )

            # TODO: implement the reference model, avoid screwing up the gradients
            dpo_loss = -F.logsigmoid(
                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
            ).mean()

            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
            chosen_rewards, rejected_rewards = (
                chosen_rewards.mean(),
                rejected_rewards.mean(),
            )

            loss = loss + dpo_loss

            self.log(
                f"{stage}/dpo_loss",
                dpo_loss,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/chosen_rewards",
                chosen_rewards,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/rejected_rewards",
                rejected_rewards,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/reward_accuracy",
                reward_accuracy,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
        )

        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
            accuracy = self.get_accuracy(
                codebook_logits[:, :, : self.model.config.num_in_codebooks],
                codebook_labels[:, :, : self.model.config.num_in_codebooks],
            )

            self.log(
                f"{stage}/top_5_accuracy_in",
                accuracy,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=True,
                logger=True,
            )

        return loss

    def get_accuracy(self, logits, labels):
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[labels == -100] = 0
        correct = correct.sum()
        accuracy = correct / (labels != -100).sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
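
The DPO branch of _step reduces to a pairwise logistic loss on the chosen/rejected codebook log-probabilities (with no reference model, per the TODO above). A toy numeric sketch of that term, using made-up log-probabilities and the default dpo_beta:

import torch
import torch.nn.functional as F

beta = 0.2  # matches the default dpo_beta above

# Toy sequence log-probabilities, shape (batch,): "chosen" (positive) vs
# "rejected" (negative) codebook continuations.
positive_logps = torch.tensor([-12.0, -8.5])
negative_logps = torch.tensor([-14.0, -8.0])

# Same form as dpo_loss in _step above.
dpo_loss = -F.logsigmoid((positive_logps - negative_logps) * beta).mean()

chosen_rewards = beta * positive_logps
rejected_rewards = beta * negative_logps
reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()

print(dpo_loss.item(), reward_accuracy.item())  # loss shrinks as the margin grows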
fish_speech/models/text2semantic/llama.py
ADDED
@@ -0,0 +1,595 @@
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.utils.checkpoint import checkpoint
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def find_multiple(n: int, k: int) -> int:
|
| 14 |
+
if n % k == 0:
|
| 15 |
+
return n
|
| 16 |
+
return n + k - (n % k)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class BaseModelArgs:
|
| 21 |
+
vocab_size: int = 32000
|
| 22 |
+
n_layer: int = 32
|
| 23 |
+
n_head: int = 32
|
| 24 |
+
dim: int = 4096
|
| 25 |
+
intermediate_size: int = None
|
| 26 |
+
n_local_heads: int = -1
|
| 27 |
+
head_dim: int = 64
|
| 28 |
+
rope_base: float = 10000
|
| 29 |
+
norm_eps: float = 1e-5
|
| 30 |
+
max_seq_len: int = 2048
|
| 31 |
+
dropout: float = 0.0
|
| 32 |
+
|
| 33 |
+
# Codebook configs
|
| 34 |
+
codebook_size: int = 160
|
| 35 |
+
num_codebooks: int = 4
|
| 36 |
+
num_in_codebooks: Optional[int] = None
|
| 37 |
+
codebook_padding_idx: int = 0
|
| 38 |
+
|
| 39 |
+
# Gradient checkpointing
|
| 40 |
+
use_gradient_checkpointing: bool = True
|
| 41 |
+
|
| 42 |
+
def __post_init__(self):
|
| 43 |
+
if self.n_local_heads == -1:
|
| 44 |
+
self.n_local_heads = self.n_head
|
| 45 |
+
if self.intermediate_size is None:
|
| 46 |
+
hidden_dim = 4 * self.dim
|
| 47 |
+
n_hidden = int(2 * hidden_dim / 3)
|
| 48 |
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
| 49 |
+
if self.num_in_codebooks is None:
|
| 50 |
+
self.num_in_codebooks = self.num_codebooks
|
| 51 |
+
self.head_dim = self.dim // self.n_head
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class NaiveModelArgs(BaseModelArgs):
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class DualARModelArgs(BaseModelArgs):
|
| 61 |
+
n_fast_layer: int = 4
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class KVCache(nn.Module):
|
| 65 |
+
def __init__(
|
| 66 |
+
self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
|
| 67 |
+
):
|
| 68 |
+
super().__init__()
|
| 69 |
+
cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
|
| 70 |
+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
| 71 |
+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
| 72 |
+
|
| 73 |
+
def update(self, input_pos, k_val, v_val):
|
| 74 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
| 75 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
| 76 |
+
|
| 77 |
+
k_out = self.k_cache
|
| 78 |
+
v_out = self.v_cache
|
| 79 |
+
k_out[:, :, input_pos] = k_val
|
| 80 |
+
v_out[:, :, input_pos] = v_val
|
| 81 |
+
|
| 82 |
+
return k_out, v_out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class TransformerForwardResult:
|
| 87 |
+
token_logits: Tensor
|
| 88 |
+
codebook_logits: Tensor
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
|
| 92 |
+
class BaseTransformerForwardResult:
|
| 93 |
+
logits: Tensor
|
| 94 |
+
hidden_states: Tensor
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class BaseTransformer(nn.Module):
|
| 98 |
+
def __init__(self, config: BaseModelArgs) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.config = config
|
| 101 |
+
|
| 102 |
+
# Slow transformer
|
| 103 |
+
self.embeddings = nn.Embedding(
|
| 104 |
+
config.vocab_size + config.codebook_size * config.num_in_codebooks,
|
| 105 |
+
config.dim,
|
| 106 |
+
)
|
| 107 |
+
self.layers = nn.ModuleList(
|
| 108 |
+
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
|
| 109 |
+
)
|
| 110 |
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 111 |
+
self.output = nn.Linear(
|
| 112 |
+
config.dim,
|
| 113 |
+
config.vocab_size,
|
| 114 |
+
bias=False,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.register_buffer(
|
| 118 |
+
"freqs_cis",
|
| 119 |
+
precompute_freqs_cis(
|
| 120 |
+
config.max_seq_len,
|
| 121 |
+
config.dim // config.n_head,
|
| 122 |
+
config.rope_base,
|
| 123 |
+
),
|
| 124 |
+
persistent=False,
|
| 125 |
+
)
|
| 126 |
+
self.register_buffer(
|
| 127 |
+
"causal_mask",
|
| 128 |
+
torch.tril(
|
| 129 |
+
torch.ones(
|
| 130 |
+
config.max_seq_len,
|
| 131 |
+
config.max_seq_len,
|
| 132 |
+
dtype=torch.bool,
|
| 133 |
+
)
|
| 134 |
+
),
|
| 135 |
+
persistent=False,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# For kv cache
|
| 139 |
+
self.max_batch_size = -1
|
| 140 |
+
self.max_seq_len = -1
|
| 141 |
+
|
| 142 |
+
def setup_caches(
|
| 143 |
+
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
| 144 |
+
):
|
| 145 |
+
if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
head_dim = self.config.dim // self.config.n_head
|
| 149 |
+
max_seq_len = find_multiple(max_seq_len, 8)
|
| 150 |
+
self.max_seq_len = max_seq_len
|
| 151 |
+
self.max_batch_size = max_batch_size
|
| 152 |
+
|
| 153 |
+
for b in self.layers:
|
| 154 |
+
b.attention.kv_cache = KVCache(
|
| 155 |
+
max_batch_size,
|
| 156 |
+
max_seq_len,
|
| 157 |
+
self.config.n_local_heads,
|
| 158 |
+
head_dim,
|
| 159 |
+
dtype=dtype,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def embed(self, x: Tensor) -> Tensor:
|
| 163 |
+
vocab_embeds = [self.embeddings(x[:, 0])]
|
| 164 |
+
for i in range(self.config.num_in_codebooks):
|
| 165 |
+
emb = self.embeddings(
|
| 166 |
+
x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
|
| 167 |
+
)
|
| 168 |
+
emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
|
| 169 |
+
vocab_embeds.append(emb)
|
| 170 |
+
|
| 171 |
+
x = torch.stack(vocab_embeds, dim=3)
|
| 172 |
+
x = x.sum(dim=3)
|
| 173 |
+
|
| 174 |
+
return x
|
| 175 |
+
|
| 176 |
+
def forward(
|
| 177 |
+
self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
|
| 178 |
+
) -> BaseTransformerForwardResult:
|
| 179 |
+
# x: (batch, num_codebooks + 1, seq_len)
|
| 180 |
+
seq_len = inp.size(2)
|
| 181 |
+
|
| 182 |
+
# Here we want to merge the embeddings of the codebooks
|
| 183 |
+
x = self.embed(inp)
|
| 184 |
+
|
| 185 |
+
mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
|
| 186 |
+
freqs_cis = self.freqs_cis[:seq_len]
|
| 187 |
+
|
| 188 |
+
# Not that the causal mask here follows the definition of scaled_dot_product_attention
|
| 189 |
+
# That is, FALSE means masked out
|
| 190 |
+
# To maintain consistency, key_padding_mask use TRUE to mask out
|
| 191 |
+
if key_padding_mask is not None:
|
| 192 |
+
mask = mask & key_padding_mask[:, None, None, :].logical_not()
|
| 193 |
+
|
| 194 |
+
for layer in self.layers:
|
| 195 |
+
if self.config.use_gradient_checkpointing and self.training:
|
| 196 |
+
x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
|
| 197 |
+
else:
|
| 198 |
+
x = layer(x, freqs_cis, mask)
|
| 199 |
+
|
| 200 |
+
# We got slow_out here
|
| 201 |
+
slow_out = self.norm(x)
|
| 202 |
+
token_logits = self.output(slow_out)
|
| 203 |
+
|
| 204 |
+
return BaseTransformerForwardResult(
|
| 205 |
+
logits=token_logits,
|
| 206 |
+
hidden_states=x,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def forward_generate(
|
| 210 |
+
self, x: Tensor, input_pos: Optional[Tensor] = None
|
| 211 |
+
) -> BaseTransformerForwardResult:
|
| 212 |
+
# This is used for generation, optimized for torch compile
|
| 213 |
+
assert (
|
| 214 |
+
self.max_seq_len != -1 and self.max_batch_size != -1
|
| 215 |
+
), "Please call setup_caches before forward_generate"
|
| 216 |
+
|
| 217 |
+
x = self.embed(x)
|
| 218 |
+
|
| 219 |
+
mask = self.causal_mask[
|
| 220 |
+
None, None, input_pos, : self.max_seq_len
|
| 221 |
+
] # (B, N, Q, K)
|
| 222 |
+
freqs_cis = self.freqs_cis[input_pos]
|
| 223 |
+
|
| 224 |
+
for layer in self.layers:
|
| 225 |
+
x = layer(x, freqs_cis, mask, input_pos=input_pos)
|
| 226 |
+
|
| 227 |
+
# If prefill, we only calculate the logits of last token
|
| 228 |
+
if x.size(1) > 1:
|
| 229 |
+
x = x[:, -1:]
|
| 230 |
+
|
| 231 |
+
# We got slow_out here
|
| 232 |
+
slow_out = self.norm(x)
|
| 233 |
+
token_logits = self.output(slow_out)
|
| 234 |
+
|
| 235 |
+
return BaseTransformerForwardResult(
|
| 236 |
+
logits=token_logits,
|
| 237 |
+
hidden_states=x,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class NaiveTransformer(BaseTransformer):
|
| 242 |
+
def __init__(self, config: NaiveModelArgs) -> None:
|
| 243 |
+
super().__init__(config)
|
| 244 |
+
|
| 245 |
+
self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 246 |
+
self.codebook_output = nn.Linear(
|
| 247 |
+
config.dim,
|
| 248 |
+
config.codebook_size * config.num_codebooks,
|
| 249 |
+
bias=False,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
|
| 253 |
+
token_logits = result.logits
|
| 254 |
+
x = result.hidden_states
|
| 255 |
+
|
| 256 |
+
# Codebook
|
| 257 |
+
codebook_logits = self.codebook_output(self.codebook_norm(x))
|
| 258 |
+
codebook_logits = rearrange(
|
| 259 |
+
codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
return TransformerForwardResult(
|
| 263 |
+
token_logits=token_logits,
|
| 264 |
+
codebook_logits=codebook_logits,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
def forward(
|
| 268 |
+
self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
|
| 269 |
+
) -> TransformerForwardResult:
|
| 270 |
+
result = super().forward(inp, key_padding_mask)
|
| 271 |
+
return self.decode(result)
|
| 272 |
+
|
| 273 |
+
def forward_generate(
|
| 274 |
+
self, x: Tensor, input_pos: Optional[Tensor] = None
|
| 275 |
+
) -> TransformerForwardResult:
|
| 276 |
+
result = super().forward_generate(x, input_pos)
|
| 277 |
+
return self.decode(result)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class DualARTransformer(BaseTransformer):
|
| 281 |
+
def __init__(self, config: DualARModelArgs) -> None:
|
| 282 |
+
super().__init__(config)
|
| 283 |
+
|
| 284 |
+
# Fast transformer
|
| 285 |
+
self.fast_embeddings = nn.Embedding(
|
| 286 |
+
config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# The equivalent batch size is so large that sdpa doesn't work
|
| 290 |
+
self.fast_layers = nn.ModuleList(
|
| 291 |
+
TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
|
| 292 |
+
)
|
| 293 |
+
self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
| 294 |
+
self.fast_output = nn.Linear(
|
| 295 |
+
config.dim,
|
| 296 |
+
config.codebook_size,
|
| 297 |
+
bias=False,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
def setup_caches(
|
| 301 |
+
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
| 302 |
+
):
|
| 303 |
+
super().setup_caches(max_batch_size, max_seq_len, dtype)
|
| 304 |
+
|
| 305 |
+
head_dim = self.config.dim // self.config.n_head
|
| 306 |
+
|
| 307 |
+
# Fast transformer
|
| 308 |
+
# The max seq len here is the number of codebooks
|
| 309 |
+
for b in self.fast_layers:
|
| 310 |
+
b.attention.kv_cache = KVCache(
|
| 311 |
+
max_batch_size,
|
| 312 |
+
self.config.num_codebooks,
|
| 313 |
+
self.config.n_local_heads,
|
| 314 |
+
head_dim,
|
| 315 |
+
dtype=dtype,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def forward(
|
| 319 |
+
self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
|
| 320 |
+
) -> TransformerForwardResult:
|
| 321 |
+
parent_result = super().forward(inp, key_padding_mask)
|
| 322 |
+
token_logits = parent_result.logits
|
| 323 |
+
x = parent_result.hidden_states
|
| 324 |
+
|
| 325 |
+
# Fast transformer
|
| 326 |
+
fast_seq_len = self.config.num_codebooks
|
| 327 |
+
fast_mask = self.causal_mask[
|
| 328 |
+
None, None, :fast_seq_len, :fast_seq_len
|
| 329 |
+
] # (B, N, Q, K)
|
| 330 |
+
fast_freqs_cis = self.freqs_cis[:fast_seq_len]
|
| 331 |
+
|
| 332 |
+
# Drop the last token and rotate left
|
| 333 |
+
codebooks = inp[:, 1:-1, 1:]
|
| 334 |
+
codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
|
| 335 |
+
codebook_embeddings = self.fast_embeddings(codebooks)
|
| 336 |
+
x = torch.cat([x[:, None], codebook_embeddings], dim=1)
|
| 337 |
+
b, s = x.size(0), x.size(2)
|
| 338 |
+
x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
|
| 339 |
+
|
| 340 |
+
# Remove padded part
|
| 341 |
+
codebooks = rearrange(codebooks, "b n s -> (b s) n")
|
| 342 |
+
codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
|
| 343 |
+
x_bs, x_len = x.size(0), x.size(1)
|
| 344 |
+
x = x[~codebook_mask]
|
| 345 |
+
|
| 346 |
+
for layer in self.fast_layers:
|
| 347 |
+
if self.config.use_gradient_checkpointing and self.training:
|
| 348 |
+
x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
|
| 349 |
+
else:
|
| 350 |
+
x = layer(x, fast_freqs_cis, fast_mask)
|
| 351 |
+
|
| 352 |
+
# unflatten the batch and num_codebooks
|
| 353 |
+
fast_out = self.fast_norm(x)
|
| 354 |
+
codebook_logits = self.fast_output(fast_out)
|
| 355 |
+
|
| 356 |
+
# Re-pad the codebook_logits
|
| 357 |
+
buffer = torch.zeros(
|
| 358 |
+
x_bs,
|
| 359 |
+
x_len,
|
| 360 |
+
codebook_logits.size(-1),
|
| 361 |
+
device=codebook_logits.device,
|
| 362 |
+
dtype=codebook_logits.dtype,
|
| 363 |
+
)
|
| 364 |
+
buffer[~codebook_mask] = codebook_logits
|
| 365 |
+
codebook_logits = buffer
|
| 366 |
+
|
| 367 |
+
assert codebook_logits.shape[1] == self.config.num_codebooks
|
| 368 |
+
codebook_logits = rearrange(
|
| 369 |
+
codebook_logits,
|
| 370 |
+
"(b s) n d -> b s n d",
|
| 371 |
+
b=b,
|
| 372 |
+
s=s,
|
| 373 |
+
n=self.config.num_codebooks,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
return TransformerForwardResult(
|
| 377 |
+
token_logits=token_logits,
|
| 378 |
+
codebook_logits=codebook_logits,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def forward_generate_fast(
|
| 382 |
+
self, x: Tensor, input_pos: Optional[Tensor] = None
|
| 383 |
+
) -> Tensor:
|
| 384 |
+
# Fast transformer
|
| 385 |
+
x = x.view(1, 1, -1)
|
| 386 |
+
|
| 387 |
+
fast_mask = self.causal_mask[
|
| 388 |
+
None, None, input_pos, : self.config.num_codebooks
|
| 389 |
+
] # (B, N, Q, K)
|
| 390 |
+
fast_freqs_cis = self.freqs_cis[input_pos]
|
| 391 |
+
|
| 392 |
+
for layer in self.fast_layers:
|
| 393 |
+
x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
|
| 394 |
+
|
| 395 |
+
# unflatten the batch and num_codebooks
|
| 396 |
+
fast_out = self.fast_norm(x) # only take the last token
|
| 397 |
+
codebook_logits = self.fast_output(fast_out)
|
| 398 |
+
|
| 399 |
+
return codebook_logits
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class TransformerBlock(nn.Module):
|
| 403 |
+
def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
|
| 404 |
+
super().__init__()
|
| 405 |
+
self.attention = Attention(config, use_sdpa=use_sdpa)
|
| 406 |
+
self.feed_forward = FeedForward(config)
|
| 407 |
+
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
|
| 408 |
+
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
|
| 409 |
+
|
| 410 |
+
def forward(
|
| 411 |
+
self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
|
| 412 |
+
) -> Tensor:
|
| 413 |
+
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
| 414 |
+
out = h + self.feed_forward(self.ffn_norm(h))
|
| 415 |
+
return out
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class Attention(nn.Module):
|
| 419 |
+
def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
|
| 420 |
+
super().__init__()
|
| 421 |
+
assert config.dim % config.n_head == 0
|
| 422 |
+
|
| 423 |
+
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
| 424 |
+
# key, query, value projections for all heads, but in a batch
|
| 425 |
+
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
| 426 |
+
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
| 427 |
+
self.kv_cache = None
|
| 428 |
+
|
| 429 |
+
self.dropout = config.dropout
|
| 430 |
+
self.n_head = config.n_head
|
| 431 |
+
self.head_dim = config.head_dim
|
| 432 |
+
self.n_local_heads = config.n_local_heads
|
| 433 |
+
self.dim = config.dim
|
| 434 |
+
self.use_sdpa = use_sdpa
|
| 435 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
| 436 |
+
|
| 437 |
+
def load_hook(self, state_dict, prefix, *args):
|
| 438 |
+
if prefix + "wq.weight" in state_dict:
|
| 439 |
+
wq = state_dict.pop(prefix + "wq.weight")
|
| 440 |
+
wk = state_dict.pop(prefix + "wk.weight")
|
| 441 |
+
wv = state_dict.pop(prefix + "wv.weight")
|
| 442 |
+
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
| 443 |
+
|
| 444 |
+
def forward(
|
| 445 |
+
self,
|
| 446 |
+
x: Tensor,
|
| 447 |
+
freqs_cis: Tensor,
|
| 448 |
+
mask: Tensor,
|
| 449 |
+
input_pos: Optional[Tensor] = None,
|
| 450 |
+
) -> Tensor:
|
| 451 |
+
bsz, seqlen, _ = x.shape
|
| 452 |
+
|
| 453 |
+
kv_size = self.n_local_heads * self.head_dim
|
| 454 |
+
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
| 455 |
+
|
| 456 |
+
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
| 457 |
+
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
| 458 |
+
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
| 459 |
+
|
| 460 |
+
q = apply_rotary_emb(q, freqs_cis)
|
| 461 |
+
k = apply_rotary_emb(k, freqs_cis)
|
| 462 |
+
|
| 463 |
+
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
| 464 |
+
|
| 465 |
+
if self.kv_cache is not None:
|
| 466 |
+
k, v = self.kv_cache.update(input_pos, k, v)
|
| 467 |
+
|
| 468 |
+
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
| 469 |
+
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
| 470 |
+
|
| 471 |
+
if self.use_sdpa:
|
| 472 |
+
y = F.scaled_dot_product_attention(
|
| 473 |
+
q,
|
| 474 |
+
k,
|
| 475 |
+
v,
|
| 476 |
+
attn_mask=mask,
|
| 477 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 478 |
+
)
|
| 479 |
+
else:
|
| 480 |
+
y = self.eq_scaled_dot_product_attention(
|
| 481 |
+
q,
|
| 482 |
+
k,
|
| 483 |
+
v,
|
| 484 |
+
attn_mask=mask,
|
| 485 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
|
| 489 |
+
|
| 490 |
+
return self.wo(y)
|
| 491 |
+
|
| 492 |
+
def eq_scaled_dot_product_attention(
|
| 493 |
+
self,
|
| 494 |
+
query,
|
| 495 |
+
key,
|
| 496 |
+
value,
|
| 497 |
+
attn_mask=None,
|
| 498 |
+
dropout_p=0.0,
|
| 499 |
+
) -> torch.Tensor:
|
| 500 |
+
# This is a standard scaled dot product attention
|
| 501 |
+
# It's low efficient, but it doesn't raise cuda error
|
| 502 |
+
|
| 503 |
+
L, S = query.size(-2), key.size(-2)
|
| 504 |
+
scale_factor = 1 / math.sqrt(query.size(-1))
|
| 505 |
+
attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
|
| 506 |
+
|
| 507 |
+
if attn_mask is not None:
|
| 508 |
+
if attn_mask.dtype == torch.bool:
|
| 509 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
| 510 |
+
else:
|
| 511 |
+
attn_bias += attn_mask
|
| 512 |
+
|
| 513 |
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
| 514 |
+
attn_weight += attn_bias
|
| 515 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 516 |
+
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
| 517 |
+
|
| 518 |
+
return attn_weight @ value
|
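For reference (not part of the diff): both branches above compute the same attention, the manual fallback simply materializes the weights explicitly as

attention(Q, K, V) = softmax(Q K^T / sqrt(d_k) + bias) V

where bias is 0 for visible positions and -inf for positions masked out by attn_mask, and dropout with probability dropout_p is applied to the softmax weights.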
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class FeedForward(nn.Module):
|
| 522 |
+
def __init__(self, config: BaseModelArgs) -> None:
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
| 525 |
+
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
| 526 |
+
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
|
| 527 |
+
|
| 528 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 529 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class RMSNorm(nn.Module):
|
| 533 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
| 534 |
+
super().__init__()
|
| 535 |
+
self.eps = eps
|
| 536 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 537 |
+
|
| 538 |
+
def _norm(self, x):
|
| 539 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
| 540 |
+
|
| 541 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 542 |
+
output = self._norm(x.float()).type_as(x)
|
| 543 |
+
return output * self.weight
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
|
| 547 |
+
freqs = 1.0 / (
|
| 548 |
+
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
| 549 |
+
)
|
| 550 |
+
t = torch.arange(seq_len, device=freqs.device)
|
| 551 |
+
freqs = torch.outer(t, freqs)
|
| 552 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 553 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
| 554 |
+
return cache.to(dtype=torch.bfloat16)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
| 558 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
| 559 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
| 560 |
+
x_out2 = torch.stack(
|
| 561 |
+
[
|
| 562 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
| 563 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
| 564 |
+
],
|
| 565 |
+
-1,
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
x_out2 = x_out2.flatten(3)
|
| 569 |
+
return x_out2.type_as(x)
|
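A quick illustrative check (not part of the commit) of the two rotary-embedding helpers above: precompute_freqs_cis caches cos/sin of theta_{t,i} = t * base^(-2i / n_elem), and apply_rotary_emb rotates consecutive feature pairs by those angles, i.e. multiplies (x0 + i*x1) by (cos + i*sin). The shapes below are arbitrary.

# Illustrative only: exercise the rotary helpers defined above with dummy shapes.
import torch

freqs_cis = precompute_freqs_cis(seq_len=8, n_elem=16)  # (8, 8, 2): cos/sin per (position, pair)
x = torch.randn(1, 8, 2, 16)                             # (batch, seq_len, n_heads, head_dim)
y = apply_rotary_emb(x, freqs_cis)
assert y.shape == x.shape                                 # the rotation preserves the shape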
| 570 |
+
|
| 571 |
+
|
| 572 |
+
if __name__ == "__main__":
|
| 573 |
+
args = DualARModelArgs(
|
| 574 |
+
max_seq_len=4096,
|
| 575 |
+
vocab_size=32312,
|
| 576 |
+
n_layer=12,
|
| 577 |
+
n_fast_layer=4,
|
| 578 |
+
n_head=12,
|
| 579 |
+
dim=768,
|
| 580 |
+
rope_base=10000,
|
| 581 |
+
norm_eps=1e-5,
|
| 582 |
+
codebook_size=128,
|
| 583 |
+
num_codebooks=4,
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
model = DualARTransformer(args)
|
| 587 |
+
model = model.cuda().bfloat16()
|
| 588 |
+
print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
|
| 589 |
+
|
| 590 |
+
inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
|
| 591 |
+
key_padding_mask = torch.zeros(2, 128).bool().cuda()
|
| 592 |
+
key_padding_mask[0, 2:] = True
|
| 593 |
+
x1 = model(inputs, key_padding_mask=key_padding_mask)
|
| 594 |
+
print(x1.token_logits.shape)
|
| 595 |
+
print(x1.codebook_logits.shape)
|
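A rough, illustrative sketch (not part of this commit) of how the pieces above fit together at inference time: the slow transformer predicts the next semantic token via forward_generate, and the fast transformer then emits the codebook values for that frame one at a time via forward_generate_fast. The prompt tensor and greedy argmax sampling below are placeholders; the project's real sampling code is not shown in this hunk.

# Illustrative decoding sketch for DualARTransformer; greedy sampling for brevity.
import torch

args = DualARModelArgs(
    max_seq_len=4096, vocab_size=32312, n_layer=12, n_fast_layer=4, n_head=12,
    dim=768, rope_base=10000, norm_eps=1e-5, codebook_size=128, num_codebooks=4,
)
model = DualARTransformer(args).cuda().bfloat16().eval()
model.setup_caches(max_batch_size=1, max_seq_len=args.max_seq_len, dtype=torch.bfloat16)

prompt = torch.randint(0, 100, (1, args.num_codebooks + 1, 16)).cuda()  # (B, C+1, T)
input_pos = torch.arange(16, device="cuda")

with torch.no_grad():
    # Slow transformer: logits for the next semantic token at the last position
    result = model.forward_generate(prompt, input_pos)
    next_token = result.logits.argmax(dim=-1)             # (1, 1)

    # Fast transformer: predict the codebooks of the new frame, one step per codebook
    x = result.hidden_states[0, -1]                        # (dim,)
    codes = []
    for i in range(args.num_codebooks):
        pos = torch.tensor([i], device="cuda")
        logits = model.forward_generate_fast(x, input_pos=pos)  # (1, 1, codebook_size)
        code = logits.argmax(dim=-1)
        codes.append(code)
        x = model.fast_embeddings(code)[0, 0]              # feed the chosen code back in

print(next_token.item(), torch.cat(codes, dim=1).shape)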
fish_speech/models/vqgan/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
from .lit_module import VQGAN
|
| 2 |
+
|
| 3 |
+
__all__ = ["VQGAN"]
|
fish_speech/models/vqgan/lit_module.py
ADDED
|
@@ -0,0 +1,442 @@
|
| 1 |
+
import itertools
|
| 2 |
+
import math
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import lightning as L
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import wandb
|
| 9 |
+
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
|
| 10 |
+
from matplotlib import pyplot as plt
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from fish_speech.models.vqgan.modules.discriminator import Discriminator
|
| 14 |
+
from fish_speech.models.vqgan.modules.wavenet import WaveNet
|
| 15 |
+
from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VQGAN(L.LightningModule):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
optimizer: Callable,
|
| 22 |
+
lr_scheduler: Callable,
|
| 23 |
+
encoder: WaveNet,
|
| 24 |
+
quantizer: nn.Module,
|
| 25 |
+
decoder: WaveNet,
|
| 26 |
+
discriminator: Discriminator,
|
| 27 |
+
vocoder: nn.Module,
|
| 28 |
+
encode_mel_transform: nn.Module,
|
| 29 |
+
gt_mel_transform: nn.Module,
|
| 30 |
+
weight_adv: float = 1.0,
|
| 31 |
+
weight_vq: float = 1.0,
|
| 32 |
+
weight_mel: float = 1.0,
|
| 33 |
+
sampling_rate: int = 44100,
|
| 34 |
+
freeze_encoder: bool = False,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
# Model parameters
|
| 39 |
+
self.optimizer_builder = optimizer
|
| 40 |
+
self.lr_scheduler_builder = lr_scheduler
|
| 41 |
+
|
| 42 |
+
# Modules
|
| 43 |
+
self.encoder = encoder
|
| 44 |
+
self.quantizer = quantizer
|
| 45 |
+
self.decoder = decoder
|
| 46 |
+
self.vocoder = vocoder
|
| 47 |
+
self.discriminator = discriminator
|
| 48 |
+
self.encode_mel_transform = encode_mel_transform
|
| 49 |
+
self.gt_mel_transform = gt_mel_transform
|
| 50 |
+
|
| 51 |
+
# A simple linear layer to project quality to condition channels
|
| 52 |
+
self.quality_projection = nn.Linear(1, 768)
|
| 53 |
+
|
| 54 |
+
# Freeze vocoder
|
| 55 |
+
for param in self.vocoder.parameters():
|
| 56 |
+
param.requires_grad = False
|
| 57 |
+
|
| 58 |
+
# Loss weights
|
| 59 |
+
self.weight_adv = weight_adv
|
| 60 |
+
self.weight_vq = weight_vq
|
| 61 |
+
self.weight_mel = weight_mel
|
| 62 |
+
|
| 63 |
+
# Other parameters
|
| 64 |
+
self.sampling_rate = sampling_rate
|
| 65 |
+
|
| 66 |
+
# Disable strict loading
|
| 67 |
+
self.strict_loading = False
|
| 68 |
+
|
| 69 |
+
# If encoder is frozen
|
| 70 |
+
if freeze_encoder:
|
| 71 |
+
for param in self.encoder.parameters():
|
| 72 |
+
param.requires_grad = False
|
| 73 |
+
|
| 74 |
+
for param in self.quantizer.parameters():
|
| 75 |
+
param.requires_grad = False
|
| 76 |
+
|
| 77 |
+
self.automatic_optimization = False
|
| 78 |
+
|
| 79 |
+
def on_save_checkpoint(self, checkpoint):
|
| 80 |
+
# Do not save vocoder
|
| 81 |
+
state_dict = checkpoint["state_dict"]
|
| 82 |
+
for name in list(state_dict.keys()):
|
| 83 |
+
if "vocoder" in name:
|
| 84 |
+
state_dict.pop(name)
|
| 85 |
+
|
| 86 |
+
def configure_optimizers(self):
|
| 87 |
+
optimizer_generator = self.optimizer_builder(
|
| 88 |
+
itertools.chain(
|
| 89 |
+
self.encoder.parameters(),
|
| 90 |
+
self.quantizer.parameters(),
|
| 91 |
+
self.decoder.parameters(),
|
| 92 |
+
self.quality_projection.parameters(),
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
optimizer_discriminator = self.optimizer_builder(
|
| 96 |
+
self.discriminator.parameters()
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
|
| 100 |
+
lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
|
| 101 |
+
|
| 102 |
+
return (
|
| 103 |
+
{
|
| 104 |
+
"optimizer": optimizer_generator,
|
| 105 |
+
"lr_scheduler": {
|
| 106 |
+
"scheduler": lr_scheduler_generator,
|
| 107 |
+
"interval": "step",
|
| 108 |
+
"name": "optimizer/generator",
|
| 109 |
+
},
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"optimizer": optimizer_discriminator,
|
| 113 |
+
"lr_scheduler": {
|
| 114 |
+
"scheduler": lr_scheduler_discriminator,
|
| 115 |
+
"interval": "step",
|
| 116 |
+
"name": "optimizer/discriminator",
|
| 117 |
+
},
|
| 118 |
+
},
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def training_step(self, batch, batch_idx):
|
| 122 |
+
optim_g, optim_d = self.optimizers()
|
| 123 |
+
|
| 124 |
+
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
|
| 125 |
+
|
| 126 |
+
audios = audios.float()
|
| 127 |
+
audios = audios[:, None, :]
|
| 128 |
+
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
encoded_mels = self.encode_mel_transform(audios)
|
| 131 |
+
gt_mels = self.gt_mel_transform(audios)
|
| 132 |
+
quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
|
| 133 |
+
quality = quality.unsqueeze(-1)
|
| 134 |
+
|
| 135 |
+
mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
|
| 136 |
+
mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
|
| 137 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
| 138 |
+
gt_mels = gt_mels * mel_masks_float_conv
|
| 139 |
+
encoded_mels = encoded_mels * mel_masks_float_conv
|
| 140 |
+
|
| 141 |
+
# Encode
|
| 142 |
+
encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
|
| 143 |
+
|
| 144 |
+
# Quantize
|
| 145 |
+
vq_result = self.quantizer(encoded_features)
|
| 146 |
+
loss_vq = getattr(vq_result, "loss", 0.0)  # fall back to 0.0 if the quantizer reports no loss
|
| 147 |
+
vq_recon_features = vq_result.z * mel_masks_float_conv
|
| 148 |
+
vq_recon_features = (
|
| 149 |
+
vq_recon_features + self.quality_projection(quality)[:, :, None]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# VQ Decode
|
| 153 |
+
gen_mel = (
|
| 154 |
+
self.decoder(
|
| 155 |
+
torch.randn_like(vq_recon_features) * mel_masks_float_conv,
|
| 156 |
+
condition=vq_recon_features,
|
| 157 |
+
)
|
| 158 |
+
* mel_masks_float_conv
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Discriminator
|
| 162 |
+
real_logits = self.discriminator(gt_mels)
|
| 163 |
+
fake_logits = self.discriminator(gen_mel.detach())
|
| 164 |
+
d_mask = F.interpolate(
|
| 165 |
+
mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
|
| 169 |
+
loss_fake = avg_with_mask(fake_logits**2, d_mask)
|
| 170 |
+
|
| 171 |
+
loss_d = loss_real + loss_fake
|
| 172 |
+
|
| 173 |
+
self.log(
|
| 174 |
+
"train/discriminator/loss",
|
| 175 |
+
loss_d,
|
| 176 |
+
on_step=True,
|
| 177 |
+
on_epoch=False,
|
| 178 |
+
prog_bar=True,
|
| 179 |
+
logger=True,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Discriminator backward
|
| 183 |
+
optim_d.zero_grad()
|
| 184 |
+
self.manual_backward(loss_d)
|
| 185 |
+
self.clip_gradients(
|
| 186 |
+
optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
| 187 |
+
)
|
| 188 |
+
optim_d.step()
|
| 189 |
+
|
| 190 |
+
# Mel loss: L1 distance, combined as a weighted sum over frequency bands
|
| 191 |
+
mel_distance = (
|
| 192 |
+
gen_mel - gt_mels
|
| 193 |
+
).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
|
| 194 |
+
loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
|
| 195 |
+
loss_mel_mid_freq = avg_with_mask(
|
| 196 |
+
mel_distance[:, 40:70, :], mel_masks_float_conv
|
| 197 |
+
)
|
| 198 |
+
loss_mel_high_freq = avg_with_mask(
|
| 199 |
+
mel_distance[:, 70:, :], mel_masks_float_conv
|
| 200 |
+
)
|
| 201 |
+
loss_mel = (
|
| 202 |
+
loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Adversarial Loss
|
| 206 |
+
fake_logits = self.discriminator(gen_mel)
|
| 207 |
+
loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
|
| 208 |
+
|
| 209 |
+
# Total loss
|
| 210 |
+
loss = (
|
| 211 |
+
self.weight_vq * loss_vq
|
| 212 |
+
+ self.weight_mel * loss_mel
|
| 213 |
+
+ self.weight_adv * loss_adv
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Log losses
|
| 217 |
+
self.log(
|
| 218 |
+
"train/generator/loss",
|
| 219 |
+
loss,
|
| 220 |
+
on_step=True,
|
| 221 |
+
on_epoch=False,
|
| 222 |
+
prog_bar=True,
|
| 223 |
+
logger=True,
|
| 224 |
+
)
|
| 225 |
+
self.log(
|
| 226 |
+
"train/generator/loss_vq",
|
| 227 |
+
loss_vq,
|
| 228 |
+
on_step=True,
|
| 229 |
+
on_epoch=False,
|
| 230 |
+
prog_bar=False,
|
| 231 |
+
logger=True,
|
| 232 |
+
)
|
| 233 |
+
self.log(
|
| 234 |
+
"train/generator/loss_mel",
|
| 235 |
+
loss_mel,
|
| 236 |
+
on_step=True,
|
| 237 |
+
on_epoch=False,
|
| 238 |
+
prog_bar=False,
|
| 239 |
+
logger=True,
|
| 240 |
+
)
|
| 241 |
+
self.log(
|
| 242 |
+
"train/generator/loss_adv",
|
| 243 |
+
loss_adv,
|
| 244 |
+
on_step=True,
|
| 245 |
+
on_epoch=False,
|
| 246 |
+
prog_bar=False,
|
| 247 |
+
logger=True,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Generator backward
|
| 251 |
+
optim_g.zero_grad()
|
| 252 |
+
self.manual_backward(loss)
|
| 253 |
+
self.clip_gradients(
|
| 254 |
+
optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
| 255 |
+
)
|
| 256 |
+
optim_g.step()
|
| 257 |
+
|
| 258 |
+
scheduler_g, scheduler_d = self.lr_schedulers()
|
| 259 |
+
scheduler_g.step()
|
| 260 |
+
scheduler_d.step()
|
| 261 |
+
|
| 262 |
+
def validation_step(self, batch: Any, batch_idx: int):
|
| 263 |
+
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
|
| 264 |
+
|
| 265 |
+
audios = audios.float()
|
| 266 |
+
audios = audios[:, None, :]
|
| 267 |
+
|
| 268 |
+
encoded_mels = self.encode_mel_transform(audios)
|
| 269 |
+
gt_mels = self.gt_mel_transform(audios)
|
| 270 |
+
|
| 271 |
+
mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
|
| 272 |
+
mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
|
| 273 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
| 274 |
+
gt_mels = gt_mels * mel_masks_float_conv
|
| 275 |
+
encoded_mels = encoded_mels * mel_masks_float_conv
|
| 276 |
+
|
| 277 |
+
# Encode
|
| 278 |
+
encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
|
| 279 |
+
|
| 280 |
+
# Quantize
|
| 281 |
+
vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
|
| 282 |
+
vq_recon_features = (
|
| 283 |
+
vq_recon_features
|
| 284 |
+
+ self.quality_projection(
|
| 285 |
+
torch.ones(
|
| 286 |
+
vq_recon_features.shape[0], 1, device=vq_recon_features.device
|
| 287 |
+
)
|
| 288 |
+
* 2
|
| 289 |
+
)[:, :, None]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# VQ Decode
|
| 293 |
+
gen_aux_mels = (
|
| 294 |
+
self.decoder(
|
| 295 |
+
torch.randn_like(vq_recon_features) * mel_masks_float_conv,
|
| 296 |
+
condition=vq_recon_features,
|
| 297 |
+
)
|
| 298 |
+
* mel_masks_float_conv
|
| 299 |
+
)
|
| 300 |
+
loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
|
| 301 |
+
|
| 302 |
+
self.log(
|
| 303 |
+
"val/loss_mel",
|
| 304 |
+
loss_mel,
|
| 305 |
+
on_step=False,
|
| 306 |
+
on_epoch=True,
|
| 307 |
+
prog_bar=False,
|
| 308 |
+
logger=True,
|
| 309 |
+
sync_dist=True,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
recon_audios = self.vocoder(gt_mels)
|
| 313 |
+
gen_aux_audios = self.vocoder(gen_aux_mels)
|
| 314 |
+
|
| 315 |
+
# only log the first batch
|
| 316 |
+
if batch_idx != 0:
|
| 317 |
+
return
|
| 318 |
+
|
| 319 |
+
for idx, (
|
| 320 |
+
gt_mel,
|
| 321 |
+
gen_aux_mel,
|
| 322 |
+
audio,
|
| 323 |
+
gen_aux_audio,
|
| 324 |
+
recon_audio,
|
| 325 |
+
audio_len,
|
| 326 |
+
) in enumerate(
|
| 327 |
+
zip(
|
| 328 |
+
gt_mels,
|
| 329 |
+
gen_aux_mels,
|
| 330 |
+
audios.cpu().float(),
|
| 331 |
+
gen_aux_audios.cpu().float(),
|
| 332 |
+
recon_audios.cpu().float(),
|
| 333 |
+
audio_lengths,
|
| 334 |
+
)
|
| 335 |
+
):
|
| 336 |
+
if idx > 4:
|
| 337 |
+
break
|
| 338 |
+
|
| 339 |
+
mel_len = audio_len // self.gt_mel_transform.hop_length
|
| 340 |
+
|
| 341 |
+
image_mels = plot_mel(
|
| 342 |
+
[
|
| 343 |
+
gt_mel[:, :mel_len],
|
| 344 |
+
gen_aux_mel[:, :mel_len],
|
| 345 |
+
],
|
| 346 |
+
[
|
| 347 |
+
"Ground-Truth",
|
| 348 |
+
"Auxiliary",
|
| 349 |
+
],
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
if isinstance(self.logger, WandbLogger):
|
| 353 |
+
self.logger.experiment.log(
|
| 354 |
+
{
|
| 355 |
+
"reconstruction_mel": wandb.Image(image_mels, caption="mels"),
|
| 356 |
+
"wavs": [
|
| 357 |
+
wandb.Audio(
|
| 358 |
+
audio[0, :audio_len],
|
| 359 |
+
sample_rate=self.sampling_rate,
|
| 360 |
+
caption="gt",
|
| 361 |
+
),
|
| 362 |
+
wandb.Audio(
|
| 363 |
+
gen_aux_audio[0, :audio_len],
|
| 364 |
+
sample_rate=self.sampling_rate,
|
| 365 |
+
caption="aux",
|
| 366 |
+
),
|
| 367 |
+
wandb.Audio(
|
| 368 |
+
recon_audio[0, :audio_len],
|
| 369 |
+
sample_rate=self.sampling_rate,
|
| 370 |
+
caption="recon",
|
| 371 |
+
),
|
| 372 |
+
],
|
| 373 |
+
},
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if isinstance(self.logger, TensorBoardLogger):
|
| 377 |
+
self.logger.experiment.add_figure(
|
| 378 |
+
f"sample-{idx}/mels",
|
| 379 |
+
image_mels,
|
| 380 |
+
global_step=self.global_step,
|
| 381 |
+
)
|
| 382 |
+
self.logger.experiment.add_audio(
|
| 383 |
+
f"sample-{idx}/wavs/gt",
|
| 384 |
+
audio[0, :audio_len],
|
| 385 |
+
self.global_step,
|
| 386 |
+
sample_rate=self.sampling_rate,
|
| 387 |
+
)
|
| 388 |
+
self.logger.experiment.add_audio(
|
| 389 |
+
f"sample-{idx}/wavs/gen",
|
| 390 |
+
gen_aux_audio[0, :audio_len],
|
| 391 |
+
self.global_step,
|
| 392 |
+
sample_rate=self.sampling_rate,
|
| 393 |
+
)
|
| 394 |
+
self.logger.experiment.add_audio(
|
| 395 |
+
f"sample-{idx}/wavs/recon",
|
| 396 |
+
recon_audio[0, :audio_len],
|
| 397 |
+
self.global_step,
|
| 398 |
+
sample_rate=self.sampling_rate,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
plt.close(image_mels)
|
| 402 |
+
|
| 403 |
+
def encode(self, audios, audio_lengths):
|
| 404 |
+
audios = audios.float()
|
| 405 |
+
|
| 406 |
+
mels = self.encode_mel_transform(audios)
|
| 407 |
+
mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
|
| 408 |
+
mel_masks = sequence_mask(mel_lengths, mels.shape[2])
|
| 409 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
| 410 |
+
mels = mels * mel_masks_float_conv
|
| 411 |
+
|
| 412 |
+
# Encode
|
| 413 |
+
encoded_features = self.encoder(mels) * mel_masks_float_conv
|
| 414 |
+
feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
|
| 415 |
+
|
| 416 |
+
return self.quantizer.encode(encoded_features), feature_lengths
|
| 417 |
+
|
| 418 |
+
def decode(self, indices, feature_lengths, return_audios=False):
|
| 419 |
+
factor = math.prod(self.quantizer.downsample_factor)
|
| 420 |
+
mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
|
| 421 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
| 422 |
+
|
| 423 |
+
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
| 424 |
+
z = (
|
| 425 |
+
z
|
| 426 |
+
+ self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
|
| 427 |
+
:, :, None
|
| 428 |
+
]
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
gen_mel = (
|
| 432 |
+
self.decoder(
|
| 433 |
+
torch.randn_like(z) * mel_masks_float_conv,
|
| 434 |
+
condition=z,
|
| 435 |
+
)
|
| 436 |
+
* mel_masks_float_conv
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
if return_audios:
|
| 440 |
+
return self.vocoder(gen_mel)
|
| 441 |
+
|
| 442 |
+
return gen_mel
|
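The losses above rely on sequence_mask and avg_with_mask from fish_speech.models.vqgan.utils, which are added elsewhere in this commit but not shown in this hunk. As a reading aid, a minimal sketch of what they are assumed to do: build a boolean mask from frame lengths, and take a mask-weighted average so padded frames do not contribute to a loss.

# Plausible stand-ins (assumption, not the actual utils.py implementations).
import torch

def sequence_mask(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    # (B,) frame counts -> (B, max_length) boolean mask, True inside each sequence
    positions = torch.arange(max_length, device=lengths.device)
    return positions[None, :] < lengths[:, None]

def avg_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mask-weighted average; `mask` broadcasts against x (e.g. (B, 1, T) vs (B, n_mels, T))
    return (x * mask).sum() / mask.sum().clamp(min=1.0)

mask = sequence_mask(torch.tensor([3, 5]), 6).float()[:, None, :]  # (2, 1, 6)
loss = avg_with_mask(torch.randn(2, 4, 6).abs(), mask)
print(loss)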
fish_speech/models/vqgan/modules/discriminator.py
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Discriminator(nn.Module):
|
| 7 |
+
def __init__(self):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
blocks = []
|
| 11 |
+
convs = [
|
| 12 |
+
(1, 64, (3, 9), 1, (1, 4)),
|
| 13 |
+
(64, 128, (3, 9), (1, 2), (1, 4)),
|
| 14 |
+
(128, 256, (3, 9), (1, 2), (1, 4)),
|
| 15 |
+
(256, 512, (3, 9), (1, 2), (1, 4)),
|
| 16 |
+
(512, 1024, (3, 3), 1, (1, 1)),
|
| 17 |
+
(1024, 1, (3, 3), 1, (1, 1)),
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
|
| 21 |
+
convs
|
| 22 |
+
):
|
| 23 |
+
blocks.append(
|
| 24 |
+
weight_norm(
|
| 25 |
+
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
| 26 |
+
)
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
if idx != len(convs) - 1:
|
| 30 |
+
blocks.append(nn.SiLU(inplace=True))
|
| 31 |
+
|
| 32 |
+
self.blocks = nn.Sequential(*blocks)
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
return self.blocks(x[:, None])[:, 0]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
model = Discriminator()
|
| 40 |
+
print(sum(p.numel() for p in model.parameters()) / 1_000_000)
|
| 41 |
+
x = torch.randn(1, 128, 1024)
|
| 42 |
+
y = model(x)
|
| 43 |
+
print(y.shape)
|
| 44 |
+
print(y)
|
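For context (not part of the diff): this discriminator consumes a (batch, n_mels, frames) mel, adds a channel axis internally, and is trained with the least-squares GAN objective in VQGAN.training_step above (there the squared errors are mask-weighted averages rather than plain means). A tiny standalone illustration of that loss shape, with random tensors standing in for the discriminator outputs:

# Illustration only: LSGAN-style losses matching VQGAN.training_step above.
import torch

real_logits = torch.randn(1, 128, 128)   # stand-in for discriminator(gt_mels)
fake_logits = torch.randn(1, 128, 128)   # stand-in for discriminator(gen_mel.detach())

loss_d = ((real_logits - 1) ** 2).mean() + (fake_logits ** 2).mean()   # discriminator
loss_g = ((fake_logits - 1) ** 2).mean()                               # generator (loss_adv)
print(loss_d.item(), loss_g.item())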
fish_speech/models/vqgan/modules/firefly.py
ADDED
|
@@ -0,0 +1,538 @@
|
| 1 |
+
# An inference-only version of the FireflyGAN model
|
| 2 |
+
|
| 3 |
+
from functools import partial
|
| 4 |
+
from math import prod
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import Conv1d
|
| 12 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 13 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 18 |
+
classname = m.__class__.__name__
|
| 19 |
+
if classname.find("Conv") != -1:
|
| 20 |
+
m.weight.data.normal_(mean, std)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_padding(kernel_size, dilation=1):
|
| 24 |
+
return (kernel_size * dilation - dilation) // 2
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ResBlock1(torch.nn.Module):
|
| 28 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
self.convs1 = nn.ModuleList(
|
| 32 |
+
[
|
| 33 |
+
weight_norm(
|
| 34 |
+
Conv1d(
|
| 35 |
+
channels,
|
| 36 |
+
channels,
|
| 37 |
+
kernel_size,
|
| 38 |
+
1,
|
| 39 |
+
dilation=dilation[0],
|
| 40 |
+
padding=get_padding(kernel_size, dilation[0]),
|
| 41 |
+
)
|
| 42 |
+
),
|
| 43 |
+
weight_norm(
|
| 44 |
+
Conv1d(
|
| 45 |
+
channels,
|
| 46 |
+
channels,
|
| 47 |
+
kernel_size,
|
| 48 |
+
1,
|
| 49 |
+
dilation=dilation[1],
|
| 50 |
+
padding=get_padding(kernel_size, dilation[1]),
|
| 51 |
+
)
|
| 52 |
+
),
|
| 53 |
+
weight_norm(
|
| 54 |
+
Conv1d(
|
| 55 |
+
channels,
|
| 56 |
+
channels,
|
| 57 |
+
kernel_size,
|
| 58 |
+
1,
|
| 59 |
+
dilation=dilation[2],
|
| 60 |
+
padding=get_padding(kernel_size, dilation[2]),
|
| 61 |
+
)
|
| 62 |
+
),
|
| 63 |
+
]
|
| 64 |
+
)
|
| 65 |
+
self.convs1.apply(init_weights)
|
| 66 |
+
|
| 67 |
+
self.convs2 = nn.ModuleList(
|
| 68 |
+
[
|
| 69 |
+
weight_norm(
|
| 70 |
+
Conv1d(
|
| 71 |
+
channels,
|
| 72 |
+
channels,
|
| 73 |
+
kernel_size,
|
| 74 |
+
1,
|
| 75 |
+
dilation=1,
|
| 76 |
+
padding=get_padding(kernel_size, 1),
|
| 77 |
+
)
|
| 78 |
+
),
|
| 79 |
+
weight_norm(
|
| 80 |
+
Conv1d(
|
| 81 |
+
channels,
|
| 82 |
+
channels,
|
| 83 |
+
kernel_size,
|
| 84 |
+
1,
|
| 85 |
+
dilation=1,
|
| 86 |
+
padding=get_padding(kernel_size, 1),
|
| 87 |
+
)
|
| 88 |
+
),
|
| 89 |
+
weight_norm(
|
| 90 |
+
Conv1d(
|
| 91 |
+
channels,
|
| 92 |
+
channels,
|
| 93 |
+
kernel_size,
|
| 94 |
+
1,
|
| 95 |
+
dilation=1,
|
| 96 |
+
padding=get_padding(kernel_size, 1),
|
| 97 |
+
)
|
| 98 |
+
),
|
| 99 |
+
]
|
| 100 |
+
)
|
| 101 |
+
self.convs2.apply(init_weights)
|
| 102 |
+
|
| 103 |
+
def forward(self, x):
|
| 104 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
| 105 |
+
xt = F.silu(x)
|
| 106 |
+
xt = c1(xt)
|
| 107 |
+
xt = F.silu(xt)
|
| 108 |
+
xt = c2(xt)
|
| 109 |
+
x = xt + x
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
def remove_parametrizations(self):
|
| 113 |
+
for conv in self.convs1:
|
| 114 |
+
remove_parametrizations(conv, tensor_name="weight")
|
| 115 |
+
for conv in self.convs2:
|
| 116 |
+
remove_parametrizations(conv, tensor_name="weight")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ParralelBlock(nn.Module):
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
channels: int,
|
| 123 |
+
kernel_sizes: tuple[int] = (3, 7, 11),
|
| 124 |
+
dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
assert len(kernel_sizes) == len(dilation_sizes)
|
| 129 |
+
|
| 130 |
+
self.blocks = nn.ModuleList()
|
| 131 |
+
for k, d in zip(kernel_sizes, dilation_sizes):
|
| 132 |
+
self.blocks.append(ResBlock1(channels, k, d))
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
|
| 136 |
+
|
| 137 |
+
def remove_parametrizations(self):
|
| 138 |
+
for block in self.blocks:
|
| 139 |
+
block.remove_parametrizations()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class HiFiGANGenerator(nn.Module):
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
*,
|
| 146 |
+
hop_length: int = 512,
|
| 147 |
+
upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
|
| 148 |
+
upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
|
| 149 |
+
resblock_kernel_sizes: tuple[int] = (3, 7, 11),
|
| 150 |
+
resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
| 151 |
+
num_mels: int = 128,
|
| 152 |
+
upsample_initial_channel: int = 512,
|
| 153 |
+
use_template: bool = True,
|
| 154 |
+
pre_conv_kernel_size: int = 7,
|
| 155 |
+
post_conv_kernel_size: int = 7,
|
| 156 |
+
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
| 157 |
+
):
|
| 158 |
+
super().__init__()
|
| 159 |
+
|
| 160 |
+
assert (
|
| 161 |
+
prod(upsample_rates) == hop_length
|
| 162 |
+
), f"hop_length must be {prod(upsample_rates)}"
|
| 163 |
+
|
| 164 |
+
self.conv_pre = weight_norm(
|
| 165 |
+
nn.Conv1d(
|
| 166 |
+
num_mels,
|
| 167 |
+
upsample_initial_channel,
|
| 168 |
+
pre_conv_kernel_size,
|
| 169 |
+
1,
|
| 170 |
+
padding=get_padding(pre_conv_kernel_size),
|
| 171 |
+
)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
self.num_upsamples = len(upsample_rates)
|
| 175 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
| 176 |
+
|
| 177 |
+
self.noise_convs = nn.ModuleList()
|
| 178 |
+
self.use_template = use_template
|
| 179 |
+
self.ups = nn.ModuleList()
|
| 180 |
+
|
| 181 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 182 |
+
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
| 183 |
+
self.ups.append(
|
| 184 |
+
weight_norm(
|
| 185 |
+
nn.ConvTranspose1d(
|
| 186 |
+
upsample_initial_channel // (2**i),
|
| 187 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
| 188 |
+
k,
|
| 189 |
+
u,
|
| 190 |
+
padding=(k - u) // 2,
|
| 191 |
+
)
|
| 192 |
+
)
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if not use_template:
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
if i + 1 < len(upsample_rates):
|
| 199 |
+
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
| 200 |
+
self.noise_convs.append(
|
| 201 |
+
Conv1d(
|
| 202 |
+
1,
|
| 203 |
+
c_cur,
|
| 204 |
+
kernel_size=stride_f0 * 2,
|
| 205 |
+
stride=stride_f0,
|
| 206 |
+
padding=stride_f0 // 2,
|
| 207 |
+
)
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
| 211 |
+
|
| 212 |
+
self.resblocks = nn.ModuleList()
|
| 213 |
+
for i in range(len(self.ups)):
|
| 214 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
| 215 |
+
self.resblocks.append(
|
| 216 |
+
ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.activation_post = post_activation()
|
| 220 |
+
self.conv_post = weight_norm(
|
| 221 |
+
nn.Conv1d(
|
| 222 |
+
ch,
|
| 223 |
+
1,
|
| 224 |
+
post_conv_kernel_size,
|
| 225 |
+
1,
|
| 226 |
+
padding=get_padding(post_conv_kernel_size),
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
self.ups.apply(init_weights)
|
| 230 |
+
self.conv_post.apply(init_weights)
|
| 231 |
+
|
| 232 |
+
def forward(self, x, template=None):
|
| 233 |
+
x = self.conv_pre(x)
|
| 234 |
+
|
| 235 |
+
for i in range(self.num_upsamples):
|
| 236 |
+
x = F.silu(x, inplace=True)
|
| 237 |
+
x = self.ups[i](x)
|
| 238 |
+
|
| 239 |
+
if self.use_template:
|
| 240 |
+
x = x + self.noise_convs[i](template)
|
| 241 |
+
|
| 242 |
+
if self.training and getattr(self, "checkpointing", False):  # checkpointing is not set in __init__; default to False
|
| 243 |
+
x = checkpoint(
|
| 244 |
+
self.resblocks[i],
|
| 245 |
+
x,
|
| 246 |
+
use_reentrant=False,
|
| 247 |
+
)
|
| 248 |
+
else:
|
| 249 |
+
x = self.resblocks[i](x)
|
| 250 |
+
|
| 251 |
+
x = self.activation_post(x)
|
| 252 |
+
x = self.conv_post(x)
|
| 253 |
+
x = torch.tanh(x)
|
| 254 |
+
|
| 255 |
+
return x
|
| 256 |
+
|
| 257 |
+
def remove_parametrizations(self):
|
| 258 |
+
for up in self.ups:
|
| 259 |
+
remove_parametrizations(up, tensor_name="weight")
|
| 260 |
+
for block in self.resblocks:
|
| 261 |
+
block.remove_parametrizations()
|
| 262 |
+
remove_parametrizations(self.conv_pre, tensor_name="weight")
|
| 263 |
+
remove_parametrizations(self.conv_post, tensor_name="weight")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# DropPath copied from timm library
|
| 267 |
+
def drop_path(
|
| 268 |
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
| 269 |
+
):
|
| 270 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 271 |
+
|
| 272 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 273 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 274 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 275 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 276 |
+
'survival rate' as the argument.
|
| 277 |
+
|
| 278 |
+
""" # noqa: E501
|
| 279 |
+
|
| 280 |
+
if drop_prob == 0.0 or not training:
|
| 281 |
+
return x
|
| 282 |
+
keep_prob = 1 - drop_prob
|
| 283 |
+
shape = (x.shape[0],) + (1,) * (
|
| 284 |
+
x.ndim - 1
|
| 285 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 286 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 287 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 288 |
+
random_tensor.div_(keep_prob)
|
| 289 |
+
return x * random_tensor
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class DropPath(nn.Module):
|
| 293 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
|
| 294 |
+
|
| 295 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
| 296 |
+
super(DropPath, self).__init__()
|
| 297 |
+
self.drop_prob = drop_prob
|
| 298 |
+
self.scale_by_keep = scale_by_keep
|
| 299 |
+
|
| 300 |
+
def forward(self, x):
|
| 301 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 302 |
+
|
| 303 |
+
def extra_repr(self):
|
| 304 |
+
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class LayerNorm(nn.Module):
|
| 308 |
+
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 309 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
| 310 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
| 311 |
+
with shape (batch_size, channels, height, width).
|
| 312 |
+
""" # noqa: E501
|
| 313 |
+
|
| 314 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 315 |
+
super().__init__()
|
| 316 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 317 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 318 |
+
self.eps = eps
|
| 319 |
+
self.data_format = data_format
|
| 320 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 321 |
+
raise NotImplementedError
|
| 322 |
+
self.normalized_shape = (normalized_shape,)
|
| 323 |
+
|
| 324 |
+
def forward(self, x):
|
| 325 |
+
if self.data_format == "channels_last":
|
| 326 |
+
return F.layer_norm(
|
| 327 |
+
x, self.normalized_shape, self.weight, self.bias, self.eps
|
| 328 |
+
)
|
| 329 |
+
elif self.data_format == "channels_first":
|
| 330 |
+
u = x.mean(1, keepdim=True)
|
| 331 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 332 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 333 |
+
x = self.weight[:, None] * x + self.bias[:, None]
|
| 334 |
+
return x
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
|
| 338 |
+
class ConvNeXtBlock(nn.Module):
|
| 339 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
| 340 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 341 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 342 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
dim (int): Number of input channels.
|
| 346 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 347 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 348 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
| 349 |
+
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
| 350 |
+
dilation (int): Dilation for depthwise conv. Default: 1.
|
| 351 |
+
""" # noqa: E501
|
| 352 |
+
|
| 353 |
+
def __init__(
|
| 354 |
+
self,
|
| 355 |
+
dim: int,
|
| 356 |
+
drop_path: float = 0.0,
|
| 357 |
+
layer_scale_init_value: float = 1e-6,
|
| 358 |
+
mlp_ratio: float = 4.0,
|
| 359 |
+
kernel_size: int = 7,
|
| 360 |
+
dilation: int = 1,
|
| 361 |
+
):
|
| 362 |
+
super().__init__()
|
| 363 |
+
|
| 364 |
+
self.dwconv = nn.Conv1d(
|
| 365 |
+
dim,
|
| 366 |
+
dim,
|
| 367 |
+
kernel_size=kernel_size,
|
| 368 |
+
padding=int(dilation * (kernel_size - 1) / 2),
|
| 369 |
+
groups=dim,
|
| 370 |
+
) # depthwise conv
|
| 371 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
| 372 |
+
self.pwconv1 = nn.Linear(
|
| 373 |
+
dim, int(mlp_ratio * dim)
|
| 374 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 375 |
+
self.act = nn.GELU()
|
| 376 |
+
self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
|
| 377 |
+
self.gamma = (
|
| 378 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 379 |
+
if layer_scale_init_value > 0
|
| 380 |
+
else None
|
| 381 |
+
)
|
| 382 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 383 |
+
|
| 384 |
+
def forward(self, x, apply_residual: bool = True):
|
| 385 |
+
input = x
|
| 386 |
+
|
| 387 |
+
x = self.dwconv(x)
|
| 388 |
+
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
| 389 |
+
x = self.norm(x)
|
| 390 |
+
x = self.pwconv1(x)
|
| 391 |
+
x = self.act(x)
|
| 392 |
+
x = self.pwconv2(x)
|
| 393 |
+
|
| 394 |
+
if self.gamma is not None:
|
| 395 |
+
x = self.gamma * x
|
| 396 |
+
|
| 397 |
+
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
| 398 |
+
x = self.drop_path(x)
|
| 399 |
+
|
| 400 |
+
if apply_residual:
|
| 401 |
+
x = input + x
|
| 402 |
+
|
| 403 |
+
return x
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class ConvNeXtEncoder(nn.Module):
|
| 407 |
+
def __init__(
|
| 408 |
+
self,
|
| 409 |
+
input_channels: int = 3,
|
| 410 |
+
depths: list[int] = [3, 3, 9, 3],
|
| 411 |
+
dims: list[int] = [96, 192, 384, 768],
|
| 412 |
+
drop_path_rate: float = 0.0,
|
| 413 |
+
layer_scale_init_value: float = 1e-6,
|
| 414 |
+
kernel_size: int = 7,
|
| 415 |
+
):
|
| 416 |
+
super().__init__()
|
| 417 |
+
assert len(depths) == len(dims)
|
| 418 |
+
|
| 419 |
+
self.downsample_layers = nn.ModuleList()
|
| 420 |
+
stem = nn.Sequential(
|
| 421 |
+
nn.Conv1d(
|
| 422 |
+
input_channels,
|
| 423 |
+
dims[0],
|
| 424 |
+
kernel_size=kernel_size,
|
| 425 |
+
padding=kernel_size // 2,
|
| 426 |
+
padding_mode="zeros",
|
| 427 |
+
),
|
| 428 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
| 429 |
+
)
|
| 430 |
+
self.downsample_layers.append(stem)
|
| 431 |
+
|
| 432 |
+
for i in range(len(depths) - 1):
|
| 433 |
+
mid_layer = nn.Sequential(
|
| 434 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
| 435 |
+
nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
|
| 436 |
+
)
|
| 437 |
+
self.downsample_layers.append(mid_layer)
|
| 438 |
+
|
| 439 |
+
self.stages = nn.ModuleList()
|
| 440 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 441 |
+
|
| 442 |
+
cur = 0
|
| 443 |
+
for i in range(len(depths)):
|
| 444 |
+
stage = nn.Sequential(
|
| 445 |
+
*[
|
| 446 |
+
ConvNeXtBlock(
|
| 447 |
+
dim=dims[i],
|
| 448 |
+
drop_path=dp_rates[cur + j],
|
| 449 |
+
layer_scale_init_value=layer_scale_init_value,
|
| 450 |
+
kernel_size=kernel_size,
|
| 451 |
+
)
|
| 452 |
+
for j in range(depths[i])
|
| 453 |
+
]
|
| 454 |
+
)
|
| 455 |
+
self.stages.append(stage)
|
| 456 |
+
cur += depths[i]
|
| 457 |
+
|
| 458 |
+
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
|
| 459 |
+
self.apply(self._init_weights)
|
| 460 |
+
|
| 461 |
+
def _init_weights(self, m):
|
| 462 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 463 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 464 |
+
nn.init.constant_(m.bias, 0)
|
| 465 |
+
|
| 466 |
+
def forward(
|
| 467 |
+
self,
|
| 468 |
+
x: torch.Tensor,
|
| 469 |
+
) -> torch.Tensor:
|
| 470 |
+
for i in range(len(self.downsample_layers)):
|
| 471 |
+
x = self.downsample_layers[i](x)
|
| 472 |
+
x = self.stages[i](x)
|
| 473 |
+
|
| 474 |
+
return self.norm(x)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class FireflyBase(nn.Module):
|
| 478 |
+
def __init__(self, ckpt_path: str = None, pretrained: bool = True):
|
| 479 |
+
super().__init__()
|
| 480 |
+
|
| 481 |
+
self.backbone = ConvNeXtEncoder(
|
| 482 |
+
input_channels=128,
|
| 483 |
+
depths=[3, 3, 9, 3],
|
| 484 |
+
dims=[128, 256, 384, 512],
|
| 485 |
+
drop_path_rate=0.2,
|
| 486 |
+
kernel_size=7,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
self.head = HiFiGANGenerator(
|
| 490 |
+
hop_length=512,
|
| 491 |
+
upsample_rates=[8, 8, 2, 2, 2],
|
| 492 |
+
upsample_kernel_sizes=[16, 16, 4, 4, 4],
|
| 493 |
+
resblock_kernel_sizes=[3, 7, 11],
|
| 494 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 495 |
+
num_mels=512,
|
| 496 |
+
upsample_initial_channel=512,
|
| 497 |
+
use_template=False,
|
| 498 |
+
pre_conv_kernel_size=13,
|
| 499 |
+
post_conv_kernel_size=13,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
if ckpt_path is not None:
|
| 503 |
+
self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
|
| 504 |
+
elif pretrained:
|
| 505 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 506 |
+
"https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
| 507 |
+
map_location="cpu",
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if "state_dict" in state_dict:
|
| 511 |
+
state_dict = state_dict["state_dict"]
|
| 512 |
+
|
| 513 |
+
if any("generator." in k for k in state_dict):
|
| 514 |
+
state_dict = {
|
| 515 |
+
k.replace("generator.", ""): v
|
| 516 |
+
for k, v in state_dict.items()
|
| 517 |
+
if "generator." in k
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
self.load_state_dict(state_dict, strict=True)
|
| 521 |
+
self.head.remove_parametrizations()
|
| 522 |
+
|
| 523 |
+
@torch.no_grad()
|
| 524 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 525 |
+
x = self.backbone(x)
|
| 526 |
+
x = self.head(x)
|
| 527 |
+
if x.ndim == 2:
|
| 528 |
+
x = x[:, None, :]
|
| 529 |
+
return x
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
if __name__ == "__main__":
|
| 533 |
+
model = FireflyBase()
|
| 534 |
+
model.eval()
|
| 535 |
+
x = torch.randn(1, 128, 128)
|
| 536 |
+
with torch.no_grad():
|
| 537 |
+
y = model(x)
|
| 538 |
+
print(y.shape)
|
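A rough usage note for the vocoder above (illustrative, not part of the commit): the upsample_rates multiply to 512, matching hop_length, so a (B, 128, T) mel batch should come back as (B, 1, T * 512) audio. The sketch below just repeats the __main__ check with a batch; instantiating FireflyBase() downloads the pretrained generator from the release URL in its constructor.

# Illustration only: run the inference-only vocoder on dummy mels.
import torch

vocoder = FireflyBase()            # pretrained weights fetched on first use
vocoder.eval()

mels = torch.randn(2, 128, 256)    # (batch, num_mels, frames)
with torch.no_grad():
    audio = vocoder(mels)
print(audio.shape)                 # expected: (2, 1, 256 * 512)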
fish_speech/models/vqgan/modules/fsq.py
ADDED
|
@@ -0,0 +1,139 @@
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from vector_quantize_pytorch import GroupedResidualFSQ
|
| 8 |
+
|
| 9 |
+
from .firefly import ConvNeXtBlock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class FSQResult:
|
| 14 |
+
z: torch.Tensor
|
| 15 |
+
codes: torch.Tensor
|
| 16 |
+
latents: torch.Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DownsampleFiniteScalarQuantize(nn.Module):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
input_dim: int = 512,
|
| 23 |
+
n_codebooks: int = 9,
|
| 24 |
+
n_groups: int = 1,
|
| 25 |
+
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
| 26 |
+
downsample_factor: tuple[int] = (2, 2),
|
| 27 |
+
downsample_dims: tuple[int] | None = None,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
if downsample_dims is None:
|
| 32 |
+
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
|
| 33 |
+
|
| 34 |
+
all_dims = (input_dim,) + tuple(downsample_dims)
|
| 35 |
+
|
| 36 |
+
self.residual_fsq = GroupedResidualFSQ(
|
| 37 |
+
dim=all_dims[-1],
|
| 38 |
+
levels=levels,
|
| 39 |
+
num_quantizers=n_codebooks,
|
| 40 |
+
groups=n_groups,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
self.downsample_factor = downsample_factor
|
| 44 |
+
self.downsample_dims = downsample_dims
|
| 45 |
+
|
| 46 |
+
self.downsample = nn.Sequential(
|
| 47 |
+
*[
|
| 48 |
+
nn.Sequential(
|
| 49 |
+
nn.Conv1d(
|
| 50 |
+
all_dims[idx],
|
| 51 |
+
all_dims[idx + 1],
|
| 52 |
+
kernel_size=factor,
|
| 53 |
+
stride=factor,
|
| 54 |
+
),
|
| 55 |
+
ConvNeXtBlock(dim=all_dims[idx + 1]),
|
| 56 |
+
)
|
| 57 |
+
for idx, factor in enumerate(downsample_factor)
|
| 58 |
+
]
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.upsample = nn.Sequential(
|
| 62 |
+
*[
|
| 63 |
+
nn.Sequential(
|
| 64 |
+
nn.ConvTranspose1d(
|
| 65 |
+
all_dims[idx + 1],
|
| 66 |
+
all_dims[idx],
|
| 67 |
+
kernel_size=factor,
|
| 68 |
+
stride=factor,
|
| 69 |
+
),
|
| 70 |
+
ConvNeXtBlock(dim=all_dims[idx]),
|
| 71 |
+
)
|
| 72 |
+
for idx, factor in reversed(list(enumerate(downsample_factor)))
|
| 73 |
+
]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.apply(self._init_weights)
|
| 77 |
+
|
| 78 |
+
def _init_weights(self, m):
|
| 79 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 80 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 81 |
+
nn.init.constant_(m.bias, 0)
|
| 82 |
+
|
| 83 |
+
def forward(self, z) -> FSQResult:
|
| 84 |
+
original_shape = z.shape
|
| 85 |
+
z = self.downsample(z)
|
| 86 |
+
quantized, indices = self.residual_fsq(z.mT)
|
| 87 |
+
result = FSQResult(
|
| 88 |
+
z=quantized.mT,
|
| 89 |
+
codes=indices.mT,
|
| 90 |
+
latents=z,
|
| 91 |
+
)
|
| 92 |
+
result.z = self.upsample(result.z)
|
| 93 |
+
|
| 94 |
+
# Pad or crop z to match original shape
|
| 95 |
+
diff = original_shape[-1] - result.z.shape[-1]
|
| 96 |
+
left = diff // 2
|
| 97 |
+
right = diff - left
|
| 98 |
+
|
| 99 |
+
if diff > 0:
|
| 100 |
+
result.z = F.pad(result.z, (left, right))
|
| 101 |
+
elif diff < 0:
|
| 102 |
+
result.z = result.z[..., left:-right]
|
| 103 |
+
|
| 104 |
+
return result
|
| 105 |
+
|
| 106 |
+
def encode(self, z):
|
| 107 |
+
z = self.downsample(z)
|
| 108 |
+
_, indices = self.residual_fsq(z.mT)
|
| 109 |
+
indices = rearrange(indices, "g b l r -> b (g r) l")
|
| 110 |
+
return indices
|
| 111 |
+
|
| 112 |
+
def decode(self, indices: torch.Tensor):
|
| 113 |
+
indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
|
| 114 |
+
z_q = self.residual_fsq.get_output_from_indices(indices)
|
| 115 |
+
z_q = self.upsample(z_q.mT)
|
| 116 |
+
return z_q
|
| 117 |
+
|
| 118 |
+
# def from_latents(self, latents: torch.Tensor):
|
| 119 |
+
# z_q, z_p, codes = super().from_latents(latents)
|
| 120 |
+
# z_q = self.upsample(z_q)
|
| 121 |
+
# return z_q, z_p, codes
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
rvq = DownsampleFiniteScalarQuantize(
|
| 126 |
+
n_codebooks=1,
|
| 127 |
+
downsample_factor=(2, 2),
|
| 128 |
+
)
|
| 129 |
+
x = torch.randn(16, 512, 80)
|
| 130 |
+
|
| 131 |
+
result = rvq(x)
|
| 132 |
+
print(rvq)
|
| 133 |
+
print(result.latents.shape, result.codes.shape, result.z.shape)
|
| 134 |
+
|
| 135 |
+
# y = rvq.from_codes(result.codes)
|
| 136 |
+
# print(y[0].shape)
|
| 137 |
+
|
| 138 |
+
# y = rvq.from_latents(result.latents)
|
| 139 |
+
# print(y[0].shape)
|
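A usage sketch (not part of the diff): the quantizer added above can round-trip features through discrete codes. Shapes below assume the defaults with a single codebook and group.

# Hedged sketch: encode features to codes and decode them back.
import torch
from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize

vq = DownsampleFiniteScalarQuantize(input_dim=512, n_codebooks=1, downsample_factor=(2, 2))
feats = torch.randn(2, 512, 80)   # (batch, channels, frames)
codes = vq.encode(feats)          # roughly (2, 1, 20): frames shrink by the downsample factors
recon = vq.decode(codes)          # roughly (2, 512, 80): decoded (dequantized) features
print(codes.shape, recon.shape)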
fish_speech/models/vqgan/modules/reference.py
ADDED
@@ -0,0 +1,113 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .wavenet import WaveNet


class ReferenceEncoder(WaveNet):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads

        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))

        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)

        self.norm = nn.LayerNorm(residual_channels)
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, residual_channels * 4),
            nn.SiLU(),
            nn.Linear(residual_channels * 4, residual_channels),
        )
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask=None):
        x = super().forward(x).mT
        B, N, C = x.shape

        # Calculate mask
        if attn_mask is not None:
            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )

        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))
        x = self.output_projection_attn(x)
        x = x.mean(1)

        return x


if __name__ == "__main__":
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        model = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        x = torch.randn(4, 128, 64)
        mask = torch.ones(4, 64, dtype=torch.bool)
        y = model(x, mask)
        print(y.shape)
        loss = F.mse_loss(y, torch.randn(4, 64))
        loss.backward()
fish_speech/models/vqgan/modules/wavenet.py
ADDED
@@ -0,0 +1,225 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class DiffusionEmbedding(nn.Module):
    """Diffusion Step Embedding"""

    def __init__(self, d_denoiser):
        super(DiffusionEmbedding, self).__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LinearNorm(nn.Module):
    """LinearNorm Projection"""

    def __init__(self, in_features, out_features, bias=False):
        super(LinearNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvNorm(nn.Module):
    """1D Convolution"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        conv_signal = self.conv(signal)

        return conv_signal


class ResidualBlock(nn.Module):
    """Residual Block"""

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super(ResidualBlock, self).__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        y = x

        if diffusion_step is not None:
            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            y = y + diffusion_step

        y = self.conv_layer(y)

        if condition is not None:
            condition = self.condition_projection(condition)
            y = y + condition

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)

        return (x + residual) / math.sqrt(2.0), skip


class WaveNet(nn.Module):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection
        self.input_projection = None
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )

        if input_channels is None:
            input_channels = residual_channels

        self.input_channels = input_channels

        # Residual layers
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection
        self.output_projection = None
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        if self.input_projection is not None:
            x = self.input_projection(x)
            x = F.silu(x)

        if t is not None:
            t = self.diffusion_embedding(t)
            t = self.mlp(t)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, condition, t)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x
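A minimal sketch (not part of the diff) of driving the WaveNet above as a conditional diffusion denoiser; every channel size and length here is an illustrative assumption, not a value used elsewhere in this change.

import torch
from fish_speech.models.vqgan.modules.wavenet import WaveNet

net = WaveNet(
    input_channels=128, output_channels=128,
    residual_channels=256, residual_layers=8,
    dilation_cycle=4, is_diffusion=True, condition_channels=64,
)
x = torch.randn(2, 128, 100)        # noisy features (batch, channels, frames)
t = torch.randint(0, 1000, (2,))    # one diffusion timestep per sample
cond = torch.randn(2, 64, 100)      # conditioning features
out = net(x, t=t, condition=cond)   # -> (2, 128, 100)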
fish_speech/models/vqgan/spectrogram.py
ADDED
@@ -0,0 +1,122 @@
import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale


class LinearSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        if y.ndim == 3:
            y = y.squeeze(1)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec


class LogMelSpectrogram(nn.Module):
    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        fb = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer(
            "fb",
            fb,
            persistent=False,
        )

    def compress(self, x: Tensor) -> Tensor:
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        return torch.exp(x)

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)

    def forward(
        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
    ) -> Tensor:
        if sample_rate is not None and sample_rate != self.sample_rate:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        x = self.apply_mel_scale(linear)
        x = self.compress(x)

        if return_linear:
            return x, self.compress(linear)

        return x
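A usage sketch for the transform above (the 44.1 kHz defaults come from the code; the one-second input is an assumption):

import torch
from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

mel_transform = LogMelSpectrogram(sample_rate=44100, n_mels=128, hop_length=512)
wav = torch.randn(1, 44100)                       # one second of audio at 44.1 kHz
mel = mel_transform(wav)                          # -> (1, 128, ~86) log-mel frames
mel, linear = mel_transform(wav, return_linear=True)  # log-mel plus log-linear spectrogram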
fish_speech/models/vqgan/utils.py
ADDED
@@ -0,0 +1,94 @@
import matplotlib
import torch
from matplotlib import pyplot as plt

matplotlib.use("Agg")


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def plot_mel(data, titles=None):
    fig, axes = plt.subplots(len(data), 1, squeeze=False)

    if titles is None:
        titles = [None for i in range(len(data))]

    plt.tight_layout()

    for i in range(len(data)):
        mel = data[i]

        if isinstance(mel, torch.Tensor):
            mel = mel.float().detach().cpu().numpy()

        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

    return fig


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]

    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    n_channels_int = n_channels[0]
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act

    return acts


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()
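For orientation, a small sketch of how the masking helpers above behave (all values chosen arbitrarily):

import torch
from fish_speech.models.vqgan.utils import avg_with_mask, sequence_mask

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, max_length=5)   # (2, 5) bool, True for valid frames
x = torch.randn(2, 4, 5)                      # (batch, channels, frames)
mean = avg_with_mask(x, mask.float())         # scalar mean over the unmasked positions only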
fish_speech/scheduler.py
ADDED
@@ -0,0 +1,22 @@
import math


def get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    final_lr_ratio: float = 0.0,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )

    return max(
        final_lr_ratio,
        0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
    )
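One way to consume the lambda above is to bind its keyword-only arguments with functools.partial and hand it to PyTorch's LambdaLR; a sketch with assumed step counts and a dummy optimizer:

from functools import partial

import torch
from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=100,
    num_training_steps=10_000,
    final_lr_ratio=0.1,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)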
fish_speech/text/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .clean import clean_text

__all__ = ["clean_text"]
fish_speech/text/clean.py
ADDED
@@ -0,0 +1,73 @@
import itertools
import re

LANGUAGE_UNICODE_RANGE_MAP = {
    "ZH": [(0x4E00, 0x9FFF)],
    "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
    "EN": [(0x0000, 0x007F)],
}

SYMBOLS_MAPPING = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "・": "-",
    "「": "'",
    "」": "'",
    ";": ",",
    ":": ",",
}

REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)
ALL_KNOWN_UTF8_RANGE = list(
    itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
)
REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
    "[^"
    + "".join(
        f"{re.escape(chr(start))}-{re.escape(chr(end))}"
        for start, end in ALL_KNOWN_UTF8_RANGE
    )
    + "]"
)


def clean_text(text):
    # Clean the text
    text = text.strip()
    # Protect phoneme tags: replace <p:(.*?)> with <PPP(.*?)PPP>
    text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
    # Replace all Chinese symbols with their English counterparts
    text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
    text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
    # Restore phoneme tags: replace <PPP(.*?)PPP> with <p:(.*?)>
    text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)

    return text
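A hedged illustration of the behaviour above: full-width punctuation is mapped to its ASCII counterpart, characters outside the known ranges are dropped, and phoneme tags survive untouched (exact output follows the tables in the file).

from fish_speech.text.clean import clean_text

print(clean_text("你好,世界!"))     # -> 你好,世界!
print(clean_text("<p:AH0> test…"))   # -> <p:AH0> test   (the ellipsis character is removed)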
fish_speech/train.py
ADDED
@@ -0,0 +1,135 @@
import os
from typing import Optional

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)

# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver
OmegaConf.register_new_resolver("eval", eval)

import fish_speech.utils as utils

log = utils.RankedLogger(__name__, rank_zero_only=True)


@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using the best weights obtained during
    training.

    This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=False)

    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False

        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # resume weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Error loading state dict: {err}")
            ckpt_path = None

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")

        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    train(cfg)


if __name__ == "__main__":
    main()
fish_speech/utils/__init__.py
ADDED
@@ -0,0 +1,21 @@
from .braceexpand import braceexpand
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, task_wrapper

__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
]
fish_speech/utils/braceexpand.py
ADDED
@@ -0,0 +1,217 @@
"""
Bash-style brace expansion
Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
License: MIT
"""

import re
import string
from itertools import chain, product
from typing import Iterable, Iterator, Optional

__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]


class UnbalancedBracesError(ValueError):
    pass


alphabet = string.ascii_uppercase + string.ascii_lowercase

int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
escape_re = re.compile(r"\\(.)")


def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
    """braceexpand(pattern) -> iterator over generated strings

    Returns an iterator over the strings resulting from brace expansion
    of pattern. This function implements Brace Expansion as described in
    bash(1), with the following limitations:

    * A pattern containing unbalanced braces will raise an
      UnbalancedBracesError exception. In bash, unbalanced braces will either
      be partly expanded or ignored.

    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
      include the characters '[]^_`' between 'Z' and 'a'.

    When escape is True (the default), characters in pattern can be
    prefixed with a backslash to cause them not to be interpreted as
    special characters for brace expansion (such as '{', '}', ',').
    To pass through a literal backslash, double it ('\\\\').

    When escape is False, backslashes in pattern have no special
    meaning and will be preserved in the output.

    Examples:

    >>> from braceexpand import braceexpand

    # Integer range
    >>> list(braceexpand('item{1..3}'))
    ['item1', 'item2', 'item3']

    # Character range
    >>> list(braceexpand('{a..c}'))
    ['a', 'b', 'c']

    # Sequence
    >>> list(braceexpand('index.html{,.backup}'))
    ['index.html', 'index.html.backup']

    # Nested patterns
    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']

    # Prefixing an integer with zero causes all numbers to be padded to
    # the same width.
    >>> list(braceexpand('{07..10}'))
    ['07', '08', '09', '10']

    # An optional increment can be specified for ranges.
    >>> list(braceexpand('{a..g..2}'))
    ['a', 'c', 'e', 'g']

    # Ranges can go in both directions.
    >>> list(braceexpand('{4..1}'))
    ['4', '3', '2', '1']

    # Numbers can be negative
    >>> list(braceexpand('{2..-1}'))
    ['2', '1', '0', '-1']

    # Unbalanced braces raise an exception.
    >>> list(braceexpand('{1{2,3}'))
    Traceback (most recent call last):
        ...
    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'

    # By default, the backslash is the escape character.
    >>> list(braceexpand(r'{1\\{2,3}'))
    ['1{2', '3']

    # Setting 'escape' to False disables backslash escaping.
    >>> list(braceexpand(r'\\{1,2}', escape=False))
    ['\\\\1', '\\\\2']

    """
    return (
        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
    )


def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
    start = 0
    pos = 0
    bracketdepth = 0
    items: list[Iterable[str]] = []

    # print 'pattern:', pattern
    while pos < len(pattern):
        if escape and pattern[pos] == "\\":
            pos += 2
            continue
        elif pattern[pos] == "{":
            if bracketdepth == 0 and pos > start:
                # print 'literal:', pattern[start:pos]
                items.append([pattern[start:pos]])
                start = pos
            bracketdepth += 1
        elif pattern[pos] == "}":
            bracketdepth -= 1
            if bracketdepth == 0:
                # print 'expression:', pattern[start+1:pos]
                expr = pattern[start + 1 : pos]
                item = parse_expression(expr, escape)
                if item is None:  # not a range or sequence
                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
                else:
                    items.append(item)
                start = pos + 1  # skip the closing brace
        pos += 1

    if bracketdepth != 0:  # unbalanced braces
        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)

    if start < pos:
        items.append([pattern[start:]])

    return ("".join(item) for item in product(*items))


def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
    int_range_match = int_range_re.match(expr)
    if int_range_match:
        return make_int_range(*int_range_match.groups())

    char_range_match = char_range_re.match(expr)
    if char_range_match:
        return make_char_range(*char_range_match.groups())

    return parse_sequence(expr, escape)


def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
    # sequence -> chain(*sequence_items)
    start = 0
    pos = 0
    bracketdepth = 0
    items: list[Iterable[str]] = []

    # print 'sequence:', seq
    while pos < len(seq):
        if escape and seq[pos] == "\\":
            pos += 2
            continue
        elif seq[pos] == "{":
            bracketdepth += 1
        elif seq[pos] == "}":
            bracketdepth -= 1
        elif seq[pos] == "," and bracketdepth == 0:
            items.append(parse_pattern(seq[start:pos], escape))
            start = pos + 1  # skip the comma
        pos += 1

    if bracketdepth != 0:
        raise UnbalancedBracesError
    if not items:
        return None

    # part after the last comma (may be the empty string)
    items.append(parse_pattern(seq[start:], escape))
    return chain(*items)


def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
    if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
        padding = max(len(left), len(right))
    else:
        padding = 0
    step = (int(incr) or 1) if incr else 1
    start = int(left)
    end = int(right)
    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
    fmt = "%0{}d".format(padding)
    return (fmt % i for i in r)


def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
    step = (int(incr) or 1) if incr else 1
    start = alphabet.index(left)
    end = alphabet.index(right)
    if start < end:
        return alphabet[start : end + 1 : step]
    else:
        end = end or -len(alphabet)
        return alphabet[start : end - 1 : -step]


if __name__ == "__main__":
    import doctest
    import sys

    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
    if failed:
        sys.exit(1)
fish_speech/utils/file.py
ADDED
@@ -0,0 +1,119 @@
import os
from glob import glob
from pathlib import Path
from typing import Union

from loguru import logger
from natsort import natsorted

AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}


def list_files(
    path: Union[Path, str],
    extensions: set[str] = None,
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter. Defaults to None.
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files. Defaults to True.

    Returns:
        list: List of files.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    # Recursively match every requested extension under the directory
    files = [file for ext in extensions for file in path.glob(f"**/*{ext}")]

    if sort:
        files = natsorted(files)

    return files


def get_latest_checkpoint(path: Path | str) -> Path | None:
    # Find the latest checkpoint
    ckpt_dir = Path(path)

    if ckpt_dir.exists() is False:
        return None

    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
    if len(ckpts) == 0:
        return None

    return ckpts[-1]


def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
    """
    Load a Bert-VITS2 style filelist.
    """

    files = set()
    results = []
    count_duplicated, count_not_found = 0, 0

    LANGUAGE_TO_LANGUAGES = {
        "zh": ["zh", "en"],
        "jp": ["jp", "en"],
        "en": ["en"],
    }

    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            splits = line.strip().split("|", maxsplit=3)
            if len(splits) != 4:
                logger.warning(f"Invalid line: {line}")
                continue

            filename, speaker, language, text = splits
            file = Path(filename)
            language = language.strip().lower()

            if language == "ja":
                language = "jp"

            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
            languages = LANGUAGE_TO_LANGUAGES[language]

            if file in files:
                logger.warning(f"Duplicated file: {file}")
                count_duplicated += 1
                continue

            if not file.exists():
                logger.warning(f"File not found: {file}")
                count_not_found += 1
                continue

            results.append((file, speaker, languages, text))

    if count_duplicated > 0:
        logger.warning(f"Total duplicated files: {count_duplicated}")

    if count_not_found > 0:
        logger.warning(f"Total files not found: {count_not_found}")

    return results
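A small usage sketch for the helpers above; both paths are placeholders for whatever dataset layout is actually used.

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist

# Collect every audio file under a dataset root (placeholder path).
audio_files = list_files("data/demo", extensions=AUDIO_EXTENSIONS, recursive=True)

# Each line of a Bert-VITS2 style filelist reads "path|speaker|language|text".
entries = load_filelist("data/demo/filelist.txt")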
fish_speech/utils/instantiators.py
ADDED
@@ -0,0 +1,50 @@
from typing import List

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger

from .logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""

    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""

    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
fish_speech/utils/logger.py
ADDED
@@ -0,0 +1,55 @@
import logging
from typing import Mapping, Optional

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = True,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError(
                    "The `rank_zero_only.rank` needs to be set before use"
                )
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None:
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)
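Usage mirrors a standard logging.LoggerAdapter; a brief sketch, assuming it runs inside a Lightning job where the process rank has already been set on rank_zero_only:

from fish_speech.utils.logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)
log.info("printed on rank 0 only")
log.info("printed on rank 1 only", rank=1)  # only honored when rank_zero_only=False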
fish_speech/utils/logging_utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
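
A minimal sketch of how `log_hyperparameters` is fed (not part of the committed diff): the config keys and the `torch.nn.Linear` stand-in are illustrative, not real fish-speech configs, and the example assumes the `fish_speech` package is importable.

import torch
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

from fish_speech.utils.logging_utils import log_hyperparameters

model = torch.nn.Linear(16, 4)  # stand-in for a real LightningModule
cfg = {
    "model": {"hidden_size": 16},
    "data": {"batch_size": 8},
    "trainer": {"max_steps": 100},
    "task_name": "demo",
}
trainer = Trainer(logger=TensorBoardLogger("results/tb_demo"), max_steps=1)

# Sends the selected config sections plus the three parameter counts to TensorBoard.
log_hyperparameters({"cfg": cfg, "model": model, "trainer": trainer})
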
fish_speech/utils/rich_utils.py
ADDED
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """  # noqa: E501
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    for field in print_order:
+        queue.append(field) if field in cfg else log.warning(
+            f"Field '{field}' not found in config. "
+            + f"Skipping '{field}' config printing..."
+        )
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+    if not cfg.get("tags"):
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        log.warning("No tags provided in config. Prompting user to input tags...")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)
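
A minimal sketch for `print_config_tree` (not part of the committed diff): a hand-built `DictConfig` with every section from `print_order` is enough to exercise the tree printer; all values below are illustrative.

from omegaconf import OmegaConf

from fish_speech.utils.rich_utils import print_config_tree

cfg = OmegaConf.create(
    {
        "data": {"batch_size": 8},
        "model": {"name": "demo", "lr": 1e-4},
        "callbacks": None,
        "logger": None,
        "trainer": {"max_steps": 1000},
        "paths": {"output_dir": "results/demo"},
        "extras": {"print_config": True},
    }
)

# Renders each top-level section as a YAML branch of a Rich tree on stdout.
print_config_tree(cfg, resolve=True, save_to_file=False)
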
fish_speech/utils/utils.py
ADDED
@@ -0,0 +1,114 @@
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+    """Applies optional utilities before the task is started.
+
+    Utilities:
+    - Ignoring python warnings
+    - Setting tags from command line
+    - Rich config printing
+    """
+
+    # return if no `extras` config
+    if not cfg.get("extras"):
+        log.warning("Extras config not found! <cfg.extras=null>")
+        return
+
+    # disable python warnings
+    if cfg.extras.get("ignore_warnings"):
+        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+        warnings.filterwarnings("ignore")
+
+    # prompt user to input tags from command line if none are provided in the config
+    if cfg.extras.get("enforce_tags"):
+        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+        enforce_tags(cfg, save_to_file=True)
+
+    # pretty print config tree using Rich library
+    if cfg.extras.get("print_config"):
+        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+        print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+    """Optional decorator that controls the failure behavior when executing the task function.
+
+    This wrapper can be used to:
+    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+    - save the exception to a `.log` file
+    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+    - etc. (adjust depending on your needs)
+
+    Example:
+    ```
+    @utils.task_wrapper
+    def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+        ...
+
+        return metric_dict, object_dict
+    ```
+    """  # noqa: E501
+
+    def wrap(cfg: DictConfig):
+        # execute the task
+        try:
+            metric_dict, object_dict = task_func(cfg=cfg)
+
+        # things to do if exception occurs
+        except Exception as ex:
+            # save exception to `.log` file
+            log.exception("")
+
+            # some hyperparameter combinations might be invalid or
+            # cause out-of-memory errors so when using hparam search
+            # plugins like Optuna, you might want to disable
+            # raising the below exception to avoid multirun failure
+            raise ex
+
+        # things to always do after either success or exception
+        finally:
+            # display output dir path in terminal
+            log.info(f"Output dir: {cfg.paths.run_dir}")
+
+            # always close wandb run (even if exception occurs so multirun won't fail)
+            if find_spec("wandb"):  # check if wandb is installed
+                import wandb
+
+                if wandb.run:
+                    log.info("Closing wandb!")
+                    wandb.finish()
+
+        return metric_dict, object_dict
+
+    return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+    """Safely retrieves value of the metric logged in LightningModule."""
+
+    if not metric_name:
+        log.info("Metric name is None! Skipping metric value retrieval...")
+        return None
+
+    if metric_name not in metric_dict:
+        raise Exception(
+            f"Metric value not found! <metric_name={metric_name}>\n"
+            "Make sure metric name logged in LightningModule is correct!\n"
+            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+        )
+
+    metric_value = metric_dict[metric_name].item()
+    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+    return metric_value
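
A minimal sketch for `task_wrapper` and `get_metric_value` (not part of the committed diff): the dummy task and the single `paths.run_dir` entry, the only config key the wrapper itself reads, are illustrative.

import torch
from omegaconf import DictConfig, OmegaConf

from fish_speech.utils.utils import get_metric_value, task_wrapper


@task_wrapper
def train(cfg: DictConfig):
    # Stand-in for a real training loop: report one logged metric.
    metric_dict = {"val/loss": torch.tensor(0.123)}
    return metric_dict, {"cfg": cfg}


cfg = OmegaConf.create({"paths": {"run_dir": "results/demo-run"}})
metric_dict, object_dict = train(cfg=cfg)
print(get_metric_value(metric_dict, "val/loss"))  # ~0.123
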
packages.txt
ADDED
@@ -0,0 +1,10 @@
+git
+curl
+build-essential
+ffmpeg
+libsm6
+libxext6
+libjpeg-dev
+zlib1g-dev
+protobuf-compiler
+cmake
pyrightconfig.json
ADDED
@@ -0,0 +1,6 @@
+{
+    "exclude": [
+        "data",
+        "filelists"
+    ]
+}
requirements.txt
ADDED
@@ -0,0 +1,24 @@
+torch
+torchaudio
+transformers>=4.35.2
+datasets>=2.14.5
+lightning>=2.1.0
+hydra-core>=1.3.2
+tensorboard>=2.14.1
+natsort>=8.4.0
+einops>=0.7.0
+librosa>=0.10.1
+rich>=13.5.3
+gradio>=4.0.0
+wandb>=0.15.11
+grpcio>=1.58.0
+kui>=1.6.0
+zibai-server>=0.9.0
+loguru>=0.6.0
+loralib>=0.1.2
+natsort>=8.4.0
+pyrootutils>=1.0.4
+vector_quantize_pytorch>=1.14.7
+samplerate>=0.2.1
+resampy>=0.4.3
+spaces>=0.26.1
setup.sh
ADDED
@@ -0,0 +1,18 @@
+#!/bin/bash
+set -e
+
+mkdir -p checkpoints
+
+if [ -e checkpoints/text2semantic-medium-v1-2k.pth ]; then
+    echo "checkpoints/text2semantic-medium-v1-2k.pth already exists"
+else
+    echo "Downloading text2semantic-medium-v1-2k.pth"
+    wget -O checkpoints/text2semantic-medium-v1-2k.pth $CKPT_SEMANTIC
+fi
+
+if [ -e checkpoints/vq-gan-group-fsq-2x1024.pth ]; then
+    echo "checkpoints/vq-gan-group-fsq-2x1024.pth already exists"
+else
+    echo "Downloading vq-gan-group-fsq-2x1024.pth"
+    wget -O checkpoints/vq-gan-group-fsq-2x1024.pth $CKPT_VQGAN
+fi
tools/extract_model.py
ADDED
@@ -0,0 +1,21 @@
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument("model_path")
+@click.argument("output_path")
+def main(model_path, output_path):
+    if model_path == output_path:
+        logger.error("Model path and output path are the same")
+        return
+
+    logger.info(f"Loading model from {model_path}")
+    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+    torch.save(state_dict, output_path)
+    logger.info(f"Model saved to {output_path}")
+
+
+if __name__ == "__main__":
+    main()
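
A minimal sketch for exercising the extraction command in-process (not part of the committed diff): it uses click's test runner on a throwaway checkpoint, with made-up file names, and assumes it is run from the repository root.

import torch
from click.testing import CliRunner

from tools.extract_model import main

# Fake a Lightning-style checkpoint that only carries a `state_dict`.
torch.save({"state_dict": {"weight": torch.zeros(2, 2)}}, "demo.ckpt")

result = CliRunner().invoke(main, ["demo.ckpt", "demo.pth"])
assert result.exit_code == 0  # demo.pth now holds only the bare state_dict
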
tools/llama/build_dataset.py
ADDED
@@ -0,0 +1,165 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+    files = sorted(files)
+
+    grouped_files = defaultdict(list)
+    for file in tqdm(files, desc=f"Grouping {root}"):
+        p = str(file.parent)
+
+        try:
+            if isinstance(text_extension, str):
+                texts = [file.with_suffix(text_extension).read_text()]
+            else:
+                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
+        except Exception as e:
+            logger.error(f"Failed to read text {file}: {e}")
+            continue
+
+        grouped_files[p].append((file, texts))
+
+    logger.info(
+        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+    )
+    for name, subset in grouped_files.items():
+        yield name, subset, "folder"
+
+
+def task_generator_filelist(filelist):
+    grouped_files = defaultdict(list)
+    for filename, speaker, _, text in load_filelist(filelist):
+        grouped_files[speaker].append((Path(filename), [text]))
+
+    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+    for speaker, values in grouped_files.items():
+        yield speaker, values, "filelist"
+
+
+def run_task(task):
+    name, subset, source = task
+
+    # Parse the files
+    sentences = []
+    for file in subset:
+        file, texts = file
+
+        np_file = file.with_suffix(".npy")
+        if np_file.exists() is False:
+            logger.warning(f"Can't find {np_file}")
+            continue
+
+        new_texts = []
+
+        for text in texts:
+            # Simple cleaning: replace { xxx } and < xxx > with space
+            text = re.sub(r"\{.*?\}", " ", text)
+            text = re.sub(r"<.*?>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+            new_texts.append(text)
+
+        try:
+            semantics = np.load(np_file)
+        except Exception as e:
+            logger.error(f"Failed to parse {file}: {e}")
+            continue
+
+        if isinstance(semantics, np.ndarray):
+            semantics = semantics.tolist()
+
+        sentences.append(
+            Sentence(
+                texts=new_texts,
+                semantics=[Semantics(values=s) for s in semantics],
+            )
+        )
+
+    # Pack the sentences
+    return pack_pb_stream(
+        TextData(
+            source=source,
+            name=name,
+            sentences=sentences,
+        )
+    )
+
+
+@click.command()
+@click.option(
+    "--input",
+    type=click.Path(path_type=Path),
+    required=True,
+    help="A folder containing the dataset or a filelist",
+    multiple=True,
+)
+@click.option(
+    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+    generator_fns = []
+
+    for f in input:
+        assert f.exists(), f"{f} not found"
+
+        if f.is_dir():
+            generator_fn = task_generator_folder(f, text_extension)
+        else:
+            generator_fn = task_generator_filelist(f)
+
+        generator_fns.append(generator_fn)
+
+    generator_fn = itertools.chain(*generator_fns)
+    output.mkdir(parents=True, exist_ok=True)
+
+    dataset_fp = None
+    tar_idx = 0
+    written_size = 0
+
+    with Pool(num_workers) as p:
+        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+            if dataset_fp is None:
+                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+            dataset_fp.write(result)
+            written_size += len(result)
+
+            if written_size > shard_size * 1024 * 1024:
+                logger.info(f"Finished writing {tar_idx} shards to {output}")
+                dataset_fp.close()
+                dataset_fp = None
+                written_size = 0
+                tar_idx += 1
+
+    if dataset_fp is not None:
+        dataset_fp.close()
+
+    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+    main()
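
A short worked example of the "simple cleaning" step in `run_task` above (not part of the committed diff; the sample string is made up): curly-brace and angle-bracket annotations are blanked out, then runs of whitespace are collapsed.

import re

text = "{laugh} Hello   <breath> world"
text = re.sub(r"\{.*?\}", " ", text)  # drop { ... } annotations
text = re.sub(r"<.*?>", " ", text)    # drop < ... > annotations
text = re.sub(r"\s+", " ", text)      # collapse whitespace
print(repr(text))                     # ' Hello world'
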