OrangeJR committed
Commit 251501d · 1 Parent(s): 1f98bdf

simplify app.py

Files changed (2):
  1. app.py +9 -220
  2. diffrhythm2/utils.py +217 -0
app.py CHANGED
@@ -1,229 +1,19 @@
 import gradio as gr
 
-import json
 import torch
-import torchaudio
 import json
-import os
 import random
 import numpy as np
-import io
-import pydub
 import base64
 import spaces
-from muq import MuQMuLan
-from diffrhythm2.cfm import CFM
-from diffrhythm2.backbones.dit import DiT
-from bigvgan.model import Generator
-from huggingface_hub import hf_hub_download
-
-STRUCT_INFO = {
-    "[start]": 500,
-    "[end]": 501,
-    "[intro]": 502,
-    "[verse]": 503,
-    "[chorus]": 504,
-    "[outro]": 505,
-    "[inst]": 506,
-    "[solo]": 507,
-    "[bridge]": 508,
-    "[hook]": 509,
-    "[break]": 510,
-    "[stop]": 511,
-    "[space]": 512
-}
-
-class CNENTokenizer():
-    def __init__(self):
-        curr_path = os.path.abspath(__file__)
-        vocab_path = os.path.join(os.path.dirname(curr_path), "g2p/g2p/vocab.json")
-        with open(vocab_path, 'r') as file:
-            self.phone2id:dict = json.load(file)['vocab']
-        self.id2phone = {v:k for (k, v) in self.phone2id.items()}
-        from g2p.g2p_generation import chn_eng_g2p
-        self.tokenizer = chn_eng_g2p
-    def encode(self, text):
-        phone, token = self.tokenizer(text)
-        token = [x+1 for x in token]
-        return token
-    def decode(self, token):
-        return "|".join([self.id2phone[x-1] for x in token])
-
-@spaces.GPU
-def prepare_model(repo_id, device, dtype):
-    diffrhythm2_ckpt_path = hf_hub_download(
-        repo_id=repo_id,
-        filename="model.safetensors",
-        local_dir="./ckpt",
-        local_files_only=False,
-    )
-    diffrhythm2_config_path = hf_hub_download(
-        repo_id=repo_id,
-        filename="model.json",
-        local_dir="./ckpt",
-        local_files_only=False,
-    )
-    with open(diffrhythm2_config_path) as f:
-        model_config = json.load(f)
-
-    model_config['use_flex_attn'] = False
-    diffrhythm2 = CFM(
-        transformer=DiT(
-            **model_config
-        ),
-        num_channels=model_config['mel_dim'],
-        block_size=model_config['block_size'],
-    )
-
-    total_params = sum(p.numel() for p in diffrhythm2.parameters())
-
-    diffrhythm2 = diffrhythm2.to(device).to(dtype)
-    if diffrhythm2_ckpt_path.endswith('.safetensors'):
-        from safetensors.torch import load_file
-        ckpt = load_file(diffrhythm2_ckpt_path)
-    else:
-        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
-    diffrhythm2.load_state_dict(ckpt)
-    print(f"Total params: {total_params:,}")
-
-    # load Mulan
-    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)
-
-    # load frontend
-    lrc_tokenizer = CNENTokenizer()
-
-    # load decoder
-    decoder_ckpt_path = hf_hub_download(
-        repo_id=repo_id,
-        filename="decoder.bin",
-        local_dir="./ckpt",
-        local_files_only=False,
-    )
-    decoder_config_path = hf_hub_download(
-        repo_id=repo_id,
-        filename="decoder.json",
-        local_dir="./ckpt",
-        local_files_only=False,
-    )
-    decoder = Generator(decoder_config_path, decoder_ckpt_path)
-    decoder = decoder.to(device).to(dtype)
-
-    return diffrhythm2, mulan, lrc_tokenizer, decoder
-
-def parse_lyrics(lyrics: str):
-    lyrics_with_time = []
-    lyrics = lyrics.split("\n")
-    for line in lyrics:
-        struct_idx = STRUCT_INFO.get(line, None)
-        if struct_idx is not None:
-            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
-        else:
-            tokens = lrc_tokenizer.encode(line.strip())
-            tokens = tokens + [STRUCT_INFO['[stop]']]
-            lyrics_with_time.append(tokens)
-    return lyrics_with_time
-
-@spaces.GPU
-def get_audio_prompt(model, audio_file, device, dtype):
-    prompt_wav, sr = torchaudio.load(audio_file)
-    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
-    if prompt_wav.shape[1] > 24000 * 10:
-        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
-        prompt_wav = prompt_wav[:, start:start+24000*10]
-    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
-    with torch.no_grad():
-        style_prompt_embed = model(wavs = prompt_wav)
-    return style_prompt_embed.squeeze(0).detach()
-
-@spaces.GPU
-def get_text_prompt(model, text, device, dtype):
-    with torch.no_grad():
-        style_prompt_embed = model(texts = [text])
-    return style_prompt_embed.squeeze(0).detach()
-
-@spaces.GPU
-def make_fake_stereo(audio, sampling_rate):
-    left_channel = audio
-    right_channel = audio.clone()
-    right_channel = right_channel * 0.8
-    delay_samples = int(0.01 * sampling_rate)
-    right_channel = torch.roll(right_channel, delay_samples)
-    right_channel[:,:delay_samples] = 0
-    # stereo_audio = np.concatenate([left_channel, right_channel], axis=0)
-    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
-
-    return stereo_audio
-
-@spaces.GPU
-def inference(
-    model,
-    decoder,
-    text,
-    style_prompt,
-    duration,
-    cfg_strength=1.0,
-    sample_steps=32,
-    fake_stereo=True,
-    odeint_method='euler',
-    file_type="wav"
-):
-    with torch.inference_mode():
-        latent = model.sample_block_cache(
-            text=text.unsqueeze(0),
-            duration=int(duration * 5),
-            style_prompt=style_prompt.unsqueeze(0),
-            steps=sample_steps,
-            cfg_strength=cfg_strength,
-            odeint_method=odeint_method
-        )
-        latent = latent.transpose(1, 2).detach()
-        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20).detach()
-
-        num_channels = 1
-        audio = audio.float().cpu().detach().squeeze()[None, :]
-        if fake_stereo:
-            audio = make_fake_stereo(audio, decoder.h.sampling_rate)
-            num_channels = 2
-
-        if file_type == 'wav':
-            return (decoder.h.sampling_rate, audio.numpy().T) # [channel, time]
-        else:
-            buffer = io.BytesIO()
-            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
-            return buffer.getvalue()
-
-@spaces.GPU
-def inference_stream(
-    model,
-    decoder,
-    text,
-    style_prompt,
-    duration,
-    cfg_strength=1.0,
-    sample_steps=32,
-    fake_stereo=True,
-    odeint_method='euler',
-    file_type="wav"
-):
-    with torch.inference_mode():
-        for audio in model.sample_cache_stream(
-            decoder=decoder,
-            text=text.unsqueeze(0),
-            duration=int(duration * 5),
-            style_prompt=style_prompt.unsqueeze(0),
-            steps=sample_steps,
-            cfg_strength=cfg_strength,
-            chunk_size=20,
-            overlap=5,
-            odeint_method=odeint_method
-        ):
-            audio = audio.float().cpu().numpy().squeeze()[None, :]
-            if fake_stereo:
-                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
-            # encoded_audio = io.BytesIO()
-            # torchaudio.save(encoded_audio, audio, decoder.h.sampling_rate, format='wav')
-            yield (decoder.h.sampling_rate, audio.T) # [channel, time]
-
+from diffrhythm2.utils import (
+    prepare_model,
+    parse_lyrics,
+    get_audio_prompt,
+    get_text_prompt,
+    inference,
+    inference_stream
+)
 
 lrc_tokenizer = None
 MAX_SEED = np.iinfo(np.int32).max
@@ -231,7 +21,6 @@ device='cuda'
 dtype=torch.float16
 diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)
 
-
 @spaces.GPU
 def infer_music(
     lrc,
@@ -251,7 +40,7 @@ def infer_music(
     torch.manual_seed(seed)
     print(seed, current_prompt_type)
     try:
-        lrc_prompt = parse_lyrics(lrc)
+        lrc_prompt = parse_lyrics(lrc_tokenizer, lrc)
         lrc_prompt = torch.tensor(sum(lrc_prompt, []), dtype=torch.long, device=device)
         if current_prompt_type == "audio":
            style_prompt = get_audio_prompt(mulan, audio_prompt, device, dtype)
diffrhythm2/utils.py ADDED
@@ -0,0 +1,217 @@
+import torch
+import torchaudio
+import os
+import json
+import random
+import io
+from huggingface_hub import hf_hub_download
+from muq import MuQMuLan
+from diffrhythm2.cfm import CFM
+from diffrhythm2.backbones.dit import DiT
+from bigvgan.model import Generator
+
+
+STRUCT_INFO = {
+    "[start]": 500,
+    "[end]": 501,
+    "[intro]": 502,
+    "[verse]": 503,
+    "[chorus]": 504,
+    "[outro]": 505,
+    "[inst]": 506,
+    "[solo]": 507,
+    "[bridge]": 508,
+    "[hook]": 509,
+    "[break]": 510,
+    "[stop]": 511,
+    "[space]": 512
+}
+
+class CNENTokenizer():
+    def __init__(self):
+        curr_path = os.path.abspath(__file__)
+        vocab_path = os.path.join(os.path.dirname((os.path.dirname(curr_path))), "g2p/g2p/vocab.json")
+        with open(vocab_path, 'r') as file:
+            self.phone2id:dict = json.load(file)['vocab']
+        self.id2phone = {v:k for (k, v) in self.phone2id.items()}
+        from g2p.g2p_generation import chn_eng_g2p
+        self.tokenizer = chn_eng_g2p
+    def encode(self, text):
+        phone, token = self.tokenizer(text)
+        token = [x+1 for x in token]
+        return token
+    def decode(self, token):
+        return "|".join([self.id2phone[x-1] for x in token])
+
+def prepare_model(repo_id, device, dtype):
+    diffrhythm2_ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="model.safetensors",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    diffrhythm2_config_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="model.json",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    with open(diffrhythm2_config_path) as f:
+        model_config = json.load(f)
+
+    model_config['use_flex_attn'] = False
+    diffrhythm2 = CFM(
+        transformer=DiT(
+            **model_config
+        ),
+        num_channels=model_config['mel_dim'],
+        block_size=model_config['block_size'],
+    )
+
+    total_params = sum(p.numel() for p in diffrhythm2.parameters())
+
+    diffrhythm2 = diffrhythm2.to(device).to(dtype)
+    if diffrhythm2_ckpt_path.endswith('.safetensors'):
+        from safetensors.torch import load_file
+        ckpt = load_file(diffrhythm2_ckpt_path)
+    else:
+        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
+    diffrhythm2.load_state_dict(ckpt)
+    print(f"Total params: {total_params:,}")
+
+    # load Mulan
+    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)
+
+    # load frontend
+    lrc_tokenizer = CNENTokenizer()
+
+    # load decoder
+    decoder_ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="decoder.bin",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    decoder_config_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="decoder.json",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    decoder = Generator(decoder_config_path, decoder_ckpt_path)
+    decoder = decoder.to(device).to(dtype)
+
+    return diffrhythm2, mulan, lrc_tokenizer, decoder
+
+def parse_lyrics(lrc_tokenizer, lyrics: str):
+    lyrics_with_time = []
+    lyrics = lyrics.split("\n")
+    for line in lyrics:
+        struct_idx = STRUCT_INFO.get(line, None)
+        if struct_idx is not None:
+            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
+        else:
+            tokens = lrc_tokenizer.encode(line.strip())
+            tokens = tokens + [STRUCT_INFO['[stop]']]
+            lyrics_with_time.append(tokens)
+    return lyrics_with_time
+
+@torch.no_grad()
+def get_audio_prompt(model, audio_file, device, dtype):
+    prompt_wav, sr = torchaudio.load(audio_file)
+    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
+    if prompt_wav.shape[1] > 24000 * 10:
+        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
+        prompt_wav = prompt_wav[:, start:start+24000*10]
+    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
+    with torch.no_grad():
+        style_prompt_embed = model(wavs = prompt_wav)
+    return style_prompt_embed.squeeze(0).detach()
+
+@torch.no_grad()
+def get_text_prompt(model, text, device, dtype):
+    with torch.no_grad():
+        style_prompt_embed = model(texts = [text])
+    return style_prompt_embed.squeeze(0).detach()
+
+@torch.no_grad()
+def make_fake_stereo(audio, sampling_rate):
+    left_channel = audio
+    right_channel = audio.clone()
+    right_channel = right_channel * 0.8
+    delay_samples = int(0.01 * sampling_rate)
+    right_channel = torch.roll(right_channel, delay_samples)
+    right_channel[:,:delay_samples] = 0
+    # stereo_audio = np.concatenate([left_channel, right_channel], axis=0)
+    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
+
+    return stereo_audio
+
+
+def inference(
+    model,
+    decoder,
+    text,
+    style_prompt,
+    duration,
+    cfg_strength=1.0,
+    sample_steps=32,
+    fake_stereo=True,
+    odeint_method='euler',
+    file_type="wav"
+):
+    with torch.inference_mode():
+        latent = model.sample_block_cache(
+            text=text.unsqueeze(0),
+            duration=int(duration * 5),
+            style_prompt=style_prompt.unsqueeze(0),
+            steps=sample_steps,
+            cfg_strength=cfg_strength,
+            odeint_method=odeint_method
+        )
+        latent = latent.transpose(1, 2).detach()
+        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20).detach()
+
+        num_channels = 1
+        audio = audio.float().cpu().detach().squeeze()[None, :]
+        if fake_stereo:
+            audio = make_fake_stereo(audio, decoder.h.sampling_rate)
+            num_channels = 2
+
+        if file_type == 'wav':
+            return (decoder.h.sampling_rate, audio.numpy().T) # [channel, time]
+        else:
+            buffer = io.BytesIO()
+            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
+            return buffer.getvalue()
+
+def inference_stream(
+    model,
+    decoder,
+    text,
+    style_prompt,
+    duration,
+    cfg_strength=1.0,
+    sample_steps=32,
+    fake_stereo=True,
+    odeint_method='euler',
+    file_type="wav"
+):
+    with torch.inference_mode():
+        for audio in model.sample_cache_stream(
+            decoder=decoder,
+            text=text.unsqueeze(0),
+            duration=int(duration * 5),
+            style_prompt=style_prompt.unsqueeze(0),
+            steps=sample_steps,
+            cfg_strength=cfg_strength,
+            chunk_size=20,
+            overlap=5,
+            odeint_method=odeint_method
+        ):
+            audio = audio.float().cpu().numpy().squeeze()[None, :]
+            if fake_stereo:
+                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
+            # encoded_audio = io.BytesIO()
+            # torchaudio.save(encoded_audio, audio, decoder.h.sampling_rate, format='wav')
+            yield (decoder.h.sampling_rate, audio.T) # [channel, time]