Spaces: Running on Zero
Update app.py
app.py CHANGED

@@ -10,6 +10,7 @@ import numpy as np
 import io
 import pydub
 import base64
+import spaces
 from muq import MuQMuLan
 from diffrhythm2.cfm import CFM
 from diffrhythm2.backbones.dit import DiT
@@ -47,7 +48,8 @@ class CNENTokenizer():
         return token
     def decode(self, token):
         return "|".join([self.id2phone[x-1] for x in token])
-
+
+@spaces.GPU
 def prepare_model(repo_id, device, dtype):
     diffrhythm2_ckpt_path = hf_hub_download(
         repo_id=repo_id,
@@ -121,6 +123,7 @@ def parse_lyrics(lyrics: str):
         lyrics_with_time.append(tokens)
     return lyrics_with_time
 
+@spaces.GPU
 def get_audio_prompt(model, audio_file, device, dtype):
     prompt_wav, sr = torchaudio.load(audio_file)
     prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
@@ -132,11 +135,13 @@ def get_audio_prompt(model, audio_file, device, dtype):
         style_prompt_embed = model(wavs = prompt_wav)
     return style_prompt_embed.squeeze(0)
 
+@spaces.GPU
 def get_text_prompt(model, text, device, dtype):
     with torch.no_grad():
         style_prompt_embed = model(texts = [text])
     return style_prompt_embed.squeeze(0)
 
+@spaces.GPU
 def make_fake_stereo(audio, sampling_rate):
     left_channel = audio
     right_channel = audio.clone()
@@ -148,7 +153,8 @@ def make_fake_stereo(audio, sampling_rate):
     stereo_audio = torch.cat([left_channel, right_channel], dim=0)
 
     return stereo_audio
-
+
+@spaces.GPU
 def inference(
     model,
     decoder,
@@ -186,6 +192,7 @@ def inference(
     torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
     return buffer.getvalue()
 
+@spaces.GPU
 def inference_stream(
     model,
     decoder,
@@ -224,7 +231,7 @@ device='cuda'
 dtype=torch.float16
 diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)
 
-
+
 @spaces.GPU
 def infer_music(
     lrc,