Upload folder using huggingface_hub
- examples/Arabic.wav +2 -2
- examples/English.wav +2 -2
- examples/French.wav +2 -2
- examples/German.wav +2 -2
- examples/Japanese.wav +2 -2
- examples/Korean.wav +2 -2
- examples/Nice English Ref.wav +2 -2
- examples/Spanish.wav +2 -2
- fish_speech/models/dac/modded_dac.py +0 -46
- fish_speech/models/text2semantic/inference.py +52 -114
- tools/download_models.py +2 -2
examples/Arabic.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:4a3c902c13fcf408c95353d91ab65f839d27584d8929c7345317956d1e9ea5bd
+size 131
examples/English.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ed744820849c8f16e03cb68e45b7d7d4b8697476a162d50ffe2cd6612a621aa6
+size 131
examples/French.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:dee830ddff631df6e0db0911a20099ddf6438a80d1da597536470ba36e2d645c
+size 131
examples/German.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cc076529638f0a4bb8d19b509b7781372c26abadcc74a7dcbc5b72b6b1e680fd
+size 131
examples/Japanese.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ba2a2c07770cb6ab36a5aa6ee953c9914773368e223359e4710897d425a25402
+size 128
examples/Korean.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:09c122b25a3ad99247179be77deeaa6ead7d93b40092347801948fea34797e48
+size 128
examples/Nice English Ref.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b895ec0d49173630cf9253c70579888cde65129fbaeda167e3b4f91593715eca
+size 128
examples/Spanish.wav
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c22d63058f58f46c6a65b6ced8faa969f403b065e822a274342b520e8e20b65f
+size 131
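Note: each diff above touches a Git LFS pointer stub, not audio data; a pointer is just the three fields shown (spec version, content SHA-256, byte size). As a hypothetical illustration, not code from this commit, the fields a pointer should carry for a local file can be computed like so:

# Hypothetical helper (not part of this commit): compute the oid and size
# fields a Git LFS pointer should carry for a given local file.
import hashlib
import os

def lfs_pointer_fields(path: str) -> str:
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        # hash the file in 1 MiB chunks to keep memory flat
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return (
        "version https://git-lfs.github.com/spec/v1\n"
        f"oid sha256:{sha.hexdigest()}\n"
        f"size {os.path.getsize(path)}\n"
    )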
fish_speech/models/dac/modded_dac.py
CHANGED
@@ -976,49 +976,3 @@ class DAC(BaseModel, CodecMixin):
         z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
         x = self.decode(z)
         return x[..., :length], vq_results
-
-
-if __name__ == "__main__":
-
-    def filter_state_dict_shapes(params, model):
-        model_state_dict = model.state_dict()
-        filtered_state_dict = {
-            k: v
-            for k, v in params.items()
-            if k in model_state_dict and v.shape == model_state_dict[k].shape
-        }
-        skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
-        if skipped_keys:
-            print(
-                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
-            )
-        return filtered_state_dict, skipped_keys
-
-    model = hydra.utils.instantiate(
-        OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
-    )
-    sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
-    filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
-    print(f"Skipped keys: {skipped_keys}")
-    model.load_state_dict(filtered_sd, strict=False)
-    model.eval()
-
-    src_audio_path = "./test.wav"
-    wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
-    if len(wave_np.shape) == 1:
-        wave_np = wave_np[None, :]
-    wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
-
-    with torch.no_grad():
-        # encode returns (indices, indices_lens)
-        indices, indices_lens = model.encode(wave_tensor)
-        print(f"Indices shape: {indices.shape}")
-        print(f"Indices lengths: {indices_lens}")
-
-        # decode takes indices and feature_lengths as its two arguments
-        fake_audio, audio_lengths = model.decode(indices, indices_lens)
-        print(f"Decoded audio shape: {fake_audio.shape}")
-        print(f"Audio lengths: {audio_lengths}")
-
-        # save the reconstructed audio
-        sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
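Note: the deleted block above was an inline encode/decode smoke test. A minimal standalone sketch of the same round trip, assuming only the signatures visible in the removed code (model.encode returning (indices, indices_lens) and model.decode taking indices plus feature lengths):

# Hypothetical standalone round-trip check (not part of this commit),
# mirroring the removed __main__ block's encode/decode usage.
import torch

def roundtrip(model, wave_tensor: torch.Tensor):
    with torch.no_grad():
        # encode -> (indices, indices_lens); decode -> (audio, audio_lengths)
        indices, indices_lens = model.encode(wave_tensor)
        audio, audio_lengths = model.decode(indices, indices_lens)
    return audio, audio_lengths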
fish_speech/models/text2semantic/inference.py
CHANGED
@@ -10,7 +10,6 @@ from typing import Literal, Optional, Tuple, Union
 import click
 import numpy as np
 import torch
-import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
@@ -21,9 +20,8 @@ from fish_speech.content_sequence import (
     TextPart,
     VQPart,
 )
-from fish_speech.
-from fish_speech.
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+from fish_speech.text import split_text
+from fish_speech.tokenizer import IM_END_TOKEN
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -37,7 +35,6 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
@@ -98,16 +95,27 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
-    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    """
+    Generate one token using dual autoregressive transformer for text-to-speech.
+
+    First generates semantic tokens, then generates acoustic codebook tokens sequentially.
+
+    Args:
+        x: Input token tensor (1, num_codebooks+1, seq_len)
+        input_pos: Position indices for input tokens (seq_len,)
+        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
+        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
+        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
+
+    Returns:
+        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
+    """
     x = model.forward_generate(x, input_pos)
 
     sampling_kwargs_main = sampling_kwargs.copy()
-    # sampling_kwargs_main["temperature"] = 0.1
-    # sampling_kwargs_main["top_p"] = 0.1
-    # sampling_kwargs_main["repetition_penalty"] = 1.0
 
     codebooks = [
         sample(
@@ -152,12 +160,7 @@ def decode_one_token_ar(
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=0)
-    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    # codebooks[1:, :] = torch.masked_fill(
-    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
-    # )
 
-    # print(codebooks)
     return codebooks
 
 
@@ -166,10 +169,24 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
    num_new_tokens: int,
-    semantic_ids: list,
     decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ):
+    """
+    Generate n tokens iteratively using the model.
+
+    Args:
+        model: The transformer model
+        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
+        input_pos: Current input position tensor
+        num_new_tokens: Number of new tokens to generate
+        semantic_ids: List of semantic token IDs
+        decode_one_token: Function to decode one token
+        **sampling_kwargs: Additional sampling parameters
+
+    Returns:
+        Generated tokens tensor of shape (num_codebooks+1, generated_len)
+    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
@@ -184,21 +201,14 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
-            )
-            if torch.cuda.is_available()
-            else nullcontext()
-        ):  # Actually better for Inductor to codegen attention here
+        with sdpa_kernel(SDPBackend.MATH):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_ids=semantic_ids,
                 **sampling_kwargs,
-            )
+            ).clone()
 
         input_pos += 1
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
@@ -223,15 +233,21 @@ def generate(
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
-
+    Generate tokens from text prompt using the transformer model.
+
+    Args:
+        model: The transformer model for generation
+        prompt: Input token tensor of shape (num_codebooks+1, seq_len)
+        max_new_tokens: Maximum number of new tokens to generate
+        decode_one_token: Function to decode one token at a time
+        **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
+
+    Returns:
+        Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
+        where total_seq_len = original_seq_len + generated_tokens_len
     """
 
-    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
-    semantic_ids = [
-        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
-    ]
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -246,7 +262,6 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
 
     codebook_dim = 1 + model.config.num_codebooks
-    # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
@@ -257,33 +272,30 @@ def generate(
     # Use non-accelerated version for now, to avoid compilation overhead
     prefill_decode = decode_one_token_ar
 
-
+    first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
-    seq[:, T : T + 1] =
+    seq[:, T : T + 1] = first_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     x = decode_n_tokens(
         model,
-
+        first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
         decode_one_token=decode_one_token,
-        semantic_ids=semantic_ids,
        **sampling_kwargs,
     )
-    # x = torch.cat(generated_tokens, dim=1)
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
 
     return seq
 
 
-def
+def init_model(checkpoint_path, device, precision, compile=False):
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     model = model.to(device=device, dtype=precision)
@@ -405,26 +417,6 @@ def generate_long(
         seg = encoded[seg_idx]
         global_encoded.append(seg)
 
-        # Do not use previous segments to generate current segment for now
-        # lengths = reversed([seg.size(1) for seg in global_encoded])
-
-        # # Pick last 2000 tokens
-        # count = 0
-        # for i, length in enumerate(lengths):
-        #     count += length
-        #     if count + length > max_length - 2048 - encoded_prompts.size(1):
-        #         break
-
-        #     if i != 0 and i % 2 == 0:
-        #         i -= 1
-
-        #     # Rotate the list, always make sure first segment is included to avoid drift
-        #     if i < len(global_encoded) - 2:
-        #         partial_encoded = global_encoded[:2] + global_encoded[-i:]
-        #     else:
-        #         partial_encoded = global_encoded
-
-        #     cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
         if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
             cat_encoded = torch.cat(
                 [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
@@ -507,7 +499,7 @@ def launch_thread_safe_queue(
     init_event = threading.Event()
 
     def worker():
-        model, decode_one_token = load_model(
+        model, decode_one_token = init_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
@@ -542,60 +534,6 @@ def launch_thread_safe_queue(
     return input_queue
 
 
-def launch_thread_safe_queue_agent(
-    checkpoint_path,
-    device,
-    precision,
-    compile: bool = False,
-):
-    input_queue = queue.Queue()
-    init_event = threading.Event()
-
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
-    config = BaseModelArgs.from_pretrained(checkpoint_path)
-
-    def worker():
-        model, decode_one_token = load_model(
-            checkpoint_path, device, precision, compile=compile, is_agent=True
-        )
-
-        with torch.device(device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_len=model.config.max_seq_len,
-                dtype=next(model.parameters()).dtype,
-            )
-        init_event.set()
-
-        while True:
-            item: GenerateRequest | None = input_queue.get()
-            if item is None:
-                break
-
-            kwargs = item.request
-            response_queue = item.response_queue
-
-            try:
-                for token in generate_agent(
-                    model=model,
-                    decode_one_token=decode_one_token,
-                    **kwargs,
-                ):
-                    response_queue.put(token)
-
-                response_queue.put("stop")
-            except Exception as e:
-                import traceback
-
-                logger.exception(f"Error in worker: {traceback.format_exc()}")
-                response_queue.put("error")
-
-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
-
-    return input_queue, tokenizer, config
-
-
 @click.command()
 @click.option(
     "--text",
@@ -654,7 +592,7 @@ def main(
 
     logger.info("Loading model ...")
     t0 = time.time()
-    model, decode_one_token = load_model(
+    model, decode_one_token = init_model(
         checkpoint_path, device, precision, compile=compile
     )
     with torch.device(device):
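Note: the decode_n_tokens hunk swaps the deprecated torch.backends.cuda.sdp_kernel context manager for the newer torch.nn.attention API, pinning scaled-dot-product attention to the math backend on all devices. A minimal sketch of that API, independent of this repo:

# Minimal sketch of the torch.nn.attention API adopted above: restrict
# scaled_dot_product_attention to the math backend inside the block.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)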
tools/download_models.py
CHANGED
@@ -22,7 +22,7 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 # 1st
-repo_id_1 = "fishaudio/
+repo_id_1 = "fishaudio/openaudio-s1-mini"
 local_dir_1 = "./checkpoints/openaudio-s1-mini"
 files_1 = [
     ".gitattributes",
@@ -31,7 +31,7 @@ files_1 = [
     "special_tokens.json",
     "tokenizer.tiktoken",
     "config.json",
-    "
+    "codec.pth",
 ]
 
 # 3rd
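Note: check_and_download_files is only visible here through its hunk headers. A minimal sketch of what such a helper might do, assuming huggingface_hub's hf_hub_download; the body below is an illustration, not this repo's implementation:

# Hypothetical sketch of a check-and-download helper built on
# huggingface_hub; skips files that already exist locally.
import os
from huggingface_hub import hf_hub_download

def check_and_download_files(repo_id: str, file_list: list[str], local_dir: str) -> None:
    for filename in file_list:
        local_path = os.path.join(local_dir, filename)
        if os.path.exists(local_path):
            continue  # already downloaded, nothing to do
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)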