Spaces:
Running
on
Zero
switch to config arg
Browse files
models.py
CHANGED
|
@@ -102,27 +102,22 @@ class Model(
|
|
| 102 |
repo_url="https://github.com/SesameAILabs/csm",
|
| 103 |
pipeline_tag="text-to-speech",
|
| 104 |
license="apache-2.0",
|
| 105 |
-
coders={
|
| 106 |
-
# Tells the class how to serialize and deserialize config.json
|
| 107 |
-
ModelArgs : (
|
| 108 |
-
lambda x: asdict(x), # Encoder: how to convert a `ModelArgs` to a valid jsonable value?
|
| 109 |
-
lambda data: ModelArgs(**data), # Decoder: how to reconstruct a `ModelArgs` from a dictionary?
|
| 110 |
-
)
|
| 111 |
-
}
|
| 112 |
):
|
| 113 |
-
def __init__(self,
|
| 114 |
super().__init__()
|
| 115 |
-
self.
|
| 116 |
|
| 117 |
-
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[
|
| 118 |
-
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[
|
| 119 |
|
| 120 |
-
self.text_embeddings = nn.Embedding(
|
| 121 |
-
self.audio_embeddings = nn.Embedding(
|
| 122 |
|
| 123 |
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
|
| 124 |
-
self.codebook0_head = nn.Linear(backbone_dim,
|
| 125 |
-
self.audio_head = nn.Parameter(
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def setup_caches(self, max_batch_size: int) -> None:
|
| 128 |
"""Setup KV caches and return a causal mask."""
|
|
@@ -131,10 +126,10 @@ class Model(
|
|
| 131 |
|
| 132 |
with device:
|
| 133 |
self.backbone.setup_caches(max_batch_size, dtype)
|
| 134 |
-
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.
|
| 135 |
|
| 136 |
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
| 137 |
-
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.
|
| 138 |
|
| 139 |
def generate_frame(
|
| 140 |
self,
|
|
@@ -175,7 +170,7 @@ class Model(
|
|
| 175 |
|
| 176 |
# Decoder caches must be reset every frame.
|
| 177 |
self.decoder.reset_caches()
|
| 178 |
-
for i in range(1, self.
|
| 179 |
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
| 180 |
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
| 181 |
dtype=dtype
|
|
@@ -195,16 +190,16 @@ class Model(
|
|
| 195 |
self.decoder.reset_caches()
|
| 196 |
|
| 197 |
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
|
| 198 |
-
return self.audio_embeddings(tokens + codebook * self.
|
| 199 |
|
| 200 |
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
| 201 |
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
|
| 202 |
|
| 203 |
audio_tokens = tokens[:, :, :-1] + (
|
| 204 |
-
self.
|
| 205 |
)
|
| 206 |
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
|
| 207 |
-
tokens.size(0), tokens.size(1), self.
|
| 208 |
)
|
| 209 |
|
| 210 |
return torch.cat([audio_embeds, text_embeds], dim=-2)
|
|
|
|
| 102 |
repo_url="https://github.com/SesameAILabs/csm",
|
| 103 |
pipeline_tag="text-to-speech",
|
| 104 |
license="apache-2.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
):
|
| 106 |
+
def __init__(self, config: ModelArgs):
    """Construct the model's transformers, embedding tables, and output heads.

    Args:
        config: ModelArgs carrying the backbone/decoder flavor names plus
            text/audio vocabulary sizes and the number of audio codebooks.
    """
    super().__init__()
    self.config = config

    # Instantiate the two transformers by flavor name; _prepare_transformer
    # also reports each model's hidden dimension.
    self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
    self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

    self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
    # A single flat table holds every codebook's vocabulary back to back,
    # hence vocab_size * num_codebooks rows (lookups add a per-codebook offset).
    self.audio_embeddings = nn.Embedding(
        config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
    )

    # Maps backbone hidden states into the decoder's dimension.
    self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
    # Codebook 0 is predicted directly from the backbone; the remaining
    # codebooks each get a (decoder_dim x vocab) matrix stacked in one Parameter.
    self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
    self.audio_head = nn.Parameter(
        torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)
    )
|
| 121 |
|
| 122 |
def setup_caches(self, max_batch_size: int) -> None:
|
| 123 |
"""Setup KV caches and return a causal mask."""
|
|
|
|
| 126 |
|
| 127 |
with device:
|
| 128 |
self.backbone.setup_caches(max_batch_size, dtype)
|
| 129 |
+
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
|
| 130 |
|
| 131 |
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
| 132 |
+
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
|
| 133 |
|
| 134 |
def generate_frame(
|
| 135 |
self,
|
|
|
|
| 170 |
|
| 171 |
# Decoder caches must be reset every frame.
|
| 172 |
self.decoder.reset_caches()
|
| 173 |
+
for i in range(1, self.config.audio_num_codebooks):
|
| 174 |
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
| 175 |
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
| 176 |
dtype=dtype
|
|
|
|
| 190 |
self.decoder.reset_caches()
|
| 191 |
|
| 192 |
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
    """Look up embeddings for *tokens* belonging to one audio codebook.

    The shared audio table stores each codebook in its own contiguous slice
    of ``audio_vocab_size`` rows, so raw ids are shifted by the codebook's
    slice offset before the lookup.
    """
    offset = codebook * self.config.audio_vocab_size
    return self.audio_embeddings(tokens + offset)
|
| 194 |
|
| 195 |
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
    """Embed a combined (audio codebooks + text) token tensor.

    The last channel of ``tokens`` holds the text token; the preceding
    channels hold one token per audio codebook. Returns audio embeddings
    concatenated with the text embedding along the codebook axis.
    (Assumes tokens is (batch, seq, num_codebooks + 1) — consistent with the
    indexing below.)
    """
    # Text token lives in the final slot; keep a singleton codebook axis.
    text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

    num_codebooks = self.config.audio_num_codebooks
    vocab_size = self.config.audio_vocab_size
    # Shift each codebook's ids into its slice of the flat audio table.
    offsets = vocab_size * torch.arange(num_codebooks, device=tokens.device)
    audio_tokens = tokens[:, :, :-1] + offsets
    flat_embeds = self.audio_embeddings(audio_tokens.view(-1))
    audio_embeds = flat_embeds.reshape(
        tokens.size(0), tokens.size(1), num_codebooks, -1
    )

    return torch.cat([audio_embeds, text_embeds], dim=-2)
|