v1.2 edit

Files changed:
- app.py (+5 -1)
- diffrhythm/model/cfm.py (+10 -16)
- diffrhythm/model/dit.py (+7 -3)
app.py
CHANGED

```diff
@@ -232,8 +232,12 @@ with gr.Blocks(css=css) as demo:
 3. **Supported Languages**
 - **Chinese and English**
 - More languages coming soon
 
-4. **Others**
+4. **Editing Function in Advanced Settings**
+- Using full-length audio as reference is recommended for best results.
+- Use -1 to represent the start/end of audio (e.g. [[-1,25], [50,-1]] means "from start to 25s" and "from 50s to end").
+
+5. **Others**
 - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
 
 """)
```
diffrhythm/model/cfm.py
CHANGED

```diff
@@ -208,27 +208,21 @@ class CFM(nn.Module):
         negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
         start_time = start_time.repeat(batch_infer_num)
         fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
-
-        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
-        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
-
-        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
-        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
-        step_cond = torch.cat([step_cond, step_cond], 0)
-        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
-        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
         def fn(t, x):
-
+            # predict flow
             pred = self.transformer(
-                x=x,
-                ...
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
+                style_prompt=style_prompt, start_time=start_time
             )
+            if cfg_strength < 1e-5:
+                return pred
 
-            ...
+            null_pred = self.transformer(
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
+                style_prompt=negative_style_prompt, start_time=start_time
+            )
+            return pred + (pred - null_pred) * cfg_strength
 
         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
```
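This hunk replaces the old batched classifier-free guidance (positive and negative conditions concatenated into one doubled batch, see the removed torch.cat block) with two explicit transformer passes combined inside fn. A standalone sketch of that combination rule, with `model` as a stand-in callable rather than the actual CFM transformer:

```python
import torch

def guided_pred(model, x, t, cond_kwargs, null_kwargs, cfg_strength):
    """Classifier-free guidance as restructured in fn(): two passes, one combine."""
    pred = model(x, t, **cond_kwargs)       # conditional prediction (text + audio cond kept)
    if cfg_strength < 1e-5:                 # guidance effectively off: skip the null pass
        return pred
    null_pred = model(x, t, **null_kwargs)  # unconditional prediction (conditions dropped)
    return pred + (pred - null_pred) * cfg_strength

# Toy usage with a stand-in model: guidance pushes the output away from the null prediction.
model = lambda x, t, scale=1.0: x * scale
x = torch.ones(2, 4)
out = guided_pred(model, x, t=0.5, cond_kwargs={"scale": 2.0}, null_kwargs={"scale": 1.0}, cfg_strength=2.0)
print(out)  # 2*x + (2*x - 1*x) * 2 = 4*x
```

The retained comment suggests the motivation: two separate passes keep results identical whether inference is batched or single, at the cost of a second forward call per step.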
diffrhythm/model/dit.py
CHANGED

```diff
@@ -162,21 +162,25 @@ class DiT(nn.Module):
     def forward(
         self,
         x: float["b n d"],  # noised input audio  # noqa: F722
-        text_embed: int["b nt"],  # text  # noqa: F722
-        text_residuals,
         cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
         drop_prompt=False,
         style_prompt=None,  # [b d t]
         start_time=None,
     ):
+
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)
 
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-
+        s_t = self.start_time_embed(start_time)
+        c = t + s_t
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
```