ASLP-lab committed
Commit 3d36e93 · verified · 1 Parent(s): c0d12aa

Update app.py

Files changed (1):
  1. app.py +10 -3
app.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
 import io
 import pydub
 import base64
+import spaces
 from muq import MuQMuLan
 from diffrhythm2.cfm import CFM
 from diffrhythm2.backbones.dit import DiT
@@ -47,7 +48,8 @@ class CNENTokenizer():
         return token
     def decode(self, token):
         return "|".join([self.id2phone[x-1] for x in token])
-
+
+@spaces.GPU
 def prepare_model(repo_id, device, dtype):
     diffrhythm2_ckpt_path = hf_hub_download(
         repo_id=repo_id,
@@ -121,6 +123,7 @@ def parse_lyrics(lyrics: str):
         lyrics_with_time.append(tokens)
     return lyrics_with_time
 
+@spaces.GPU
 def get_audio_prompt(model, audio_file, device, dtype):
     prompt_wav, sr = torchaudio.load(audio_file)
     prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
@@ -132,11 +135,13 @@ def get_audio_prompt(model, audio_file, device, dtype):
         style_prompt_embed = model(wavs = prompt_wav)
     return style_prompt_embed.squeeze(0)
 
+@spaces.GPU
 def get_text_prompt(model, text, device, dtype):
     with torch.no_grad():
         style_prompt_embed = model(texts = [text])
     return style_prompt_embed.squeeze(0)
 
+@spaces.GPU
 def make_fake_stereo(audio, sampling_rate):
     left_channel = audio
     right_channel = audio.clone()
@@ -148,7 +153,8 @@ def make_fake_stereo(audio, sampling_rate):
     stereo_audio = torch.cat([left_channel, right_channel], dim=0)
 
     return stereo_audio
-
+
+@spaces.GPU
 def inference(
     model,
     decoder,
@@ -186,6 +192,7 @@ def inference(
     torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
     return buffer.getvalue()
 
+@spaces.GPU
 def inference_stream(
     model,
     decoder,
@@ -224,7 +231,7 @@ device='cuda'
 dtype=torch.float16
 diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)
 
-import spaces
+
 @spaces.GPU
 def infer_music(
     lrc,
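
For context, `spaces` is the helper package available on Hugging Face ZeroGPU Spaces: it is typically imported at the top of the module (before anything initializes CUDA), and each function that needs the GPU is wrapped in `@spaces.GPU`, which attaches a GPU for the duration of each call. Below is a minimal sketch of the pattern this commit applies, assuming a Space running on ZeroGPU hardware; the `model` and `embed` names are illustrative placeholders, not taken from app.py.

import spaces  # import first, before anything that could initialize CUDA
import torch

# Load weights at import time; on ZeroGPU no GPU is attached yet,
# so module-level code should stay on the CPU.
model = torch.nn.Linear(16, 16).half()

@spaces.GPU  # a GPU is allocated for the duration of each call
def embed(x: torch.Tensor) -> torch.Tensor:
    # Inside the decorated function, CUDA is available.
    m = model.to("cuda")
    return m(x.half().to("cuda")).float().cpu()

For longer-running calls, the decorator also accepts a duration hint, e.g. `@spaces.GPU(duration=120)`.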