Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import librosa | |
| import torchaudio | |
| import random | |
| import json | |
| from muq import MuQMuLan, MuQ | |
| from mutagen.mp3 import MP3 | |
| import os | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from safetensors.torch import load_file | |
| from diffrhythm.model import DiT, CFM | |
| def vae_sample(mean, scale): | |
| stdev = torch.nn.functional.softplus(scale) + 1e-4 | |
| var = stdev * stdev | |
| logvar = torch.log(var) | |
| latents = torch.randn_like(mean) * stdev + mean | |
| kl = (mean * mean + var - logvar - 1).sum(1).mean() | |
| return latents, kl | |
| def normalize_audio(y, target_dbfs=0): | |
| max_amplitude = torch.max(torch.abs(y)) | |
| target_amplitude = 10.0**(target_dbfs / 20.0) | |
| scale_factor = target_amplitude / max_amplitude | |
| normalized_audio = y * scale_factor | |
| return normalized_audio | |
| def set_audio_channels(audio, target_channels): | |
| if target_channels == 1: | |
| # Convert to mono | |
| audio = audio.mean(1, keepdim=True) | |
| elif target_channels == 2: | |
| # Convert to stereo | |
| if audio.shape[1] == 1: | |
| audio = audio.repeat(1, 2, 1) | |
| elif audio.shape[1] > 2: | |
| audio = audio[:, :2, :] | |
| return audio | |
| class PadCrop(torch.nn.Module): | |
| def __init__(self, n_samples, randomize=True): | |
| super().__init__() | |
| self.n_samples = n_samples | |
| self.randomize = randomize | |
| def __call__(self, signal): | |
| n, s = signal.shape | |
| start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() | |
| end = start + self.n_samples | |
| output = signal.new_zeros([n, self.n_samples]) | |
| output[:, :min(s, self.n_samples)] = signal[:, start:end] | |
| return output | |
| def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): | |
| audio = audio.to(device) | |
| if in_sr != target_sr: | |
| resample_tf = T.Resample(in_sr, target_sr).to(device) | |
| audio = resample_tf(audio) | |
| if target_length is None: | |
| target_length = audio.shape[-1] | |
| audio = PadCrop(target_length, randomize=False)(audio) | |
| # Add batch dimension | |
| if audio.dim() == 1: | |
| audio = audio.unsqueeze(0).unsqueeze(0) | |
| elif audio.dim() == 2: | |
| audio = audio.unsqueeze(0) | |
| audio = set_audio_channels(audio, target_channels) | |
| return audio | |
| def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128): | |
| downsampling_ratio = 2048 | |
| io_channels = 2 | |
| if not chunked: | |
| return vae_model.decode_export(latents) | |
| else: | |
| # chunked decoding | |
| hop_size = chunk_size - overlap | |
| total_size = latents.shape[2] | |
| batch_size = latents.shape[0] | |
| chunks = [] | |
| i = 0 | |
| for i in range(0, total_size - chunk_size + 1, hop_size): | |
| chunk = latents[:, :, i : i + chunk_size] | |
| chunks.append(chunk) | |
| if i + chunk_size != total_size: | |
| # Final chunk | |
| chunk = latents[:, :, -chunk_size:] | |
| chunks.append(chunk) | |
| chunks = torch.stack(chunks) | |
| num_chunks = chunks.shape[0] | |
| # samples_per_latent is just the downsampling ratio | |
| samples_per_latent = downsampling_ratio | |
| # Create an empty waveform, we will populate it with chunks as decode them | |
| y_size = total_size * samples_per_latent | |
| y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device) | |
| for i in range(num_chunks): | |
| x_chunk = chunks[i, :] | |
| # decode the chunk | |
| y_chunk = vae_model.decode_export(x_chunk) | |
| # figure out where to put the audio along the time domain | |
| if i == num_chunks - 1: | |
| # final chunk always goes at the end | |
| t_end = y_size | |
| t_start = t_end - y_chunk.shape[2] | |
| else: | |
| t_start = i * hop_size * samples_per_latent | |
| t_end = t_start + chunk_size * samples_per_latent | |
| # remove the edges of the overlaps | |
| ol = (overlap // 2) * samples_per_latent | |
| chunk_start = 0 | |
| chunk_end = y_chunk.shape[2] | |
| if i > 0: | |
| # no overlap for the start of the first chunk | |
| t_start += ol | |
| chunk_start += ol | |
| if i < num_chunks - 1: | |
| # no overlap for the end of the last chunk | |
| t_end -= ol | |
| chunk_end -= ol | |
| # paste the chunked audio into our y_final output audio | |
| y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end] | |
| return y_final | |
| def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128): | |
| downsampling_ratio = 2048 | |
| latent_dim = 128 | |
| if not chunked: | |
| # default behavior. Encode the entire audio in parallel | |
| return vae_model.encode_export(audio) | |
| else: | |
| # CHUNKED ENCODING | |
| # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) | |
| samples_per_latent = downsampling_ratio | |
| total_size = audio.shape[2] # in samples | |
| batch_size = audio.shape[0] | |
| chunk_size *= samples_per_latent # converting metric in latents to samples | |
| overlap *= samples_per_latent # converting metric in latents to samples | |
| hop_size = chunk_size - overlap | |
| chunks = [] | |
| for i in range(0, total_size - chunk_size + 1, hop_size): | |
| chunk = audio[:,:,i:i+chunk_size] | |
| chunks.append(chunk) | |
| if i+chunk_size != total_size: | |
| # Final chunk | |
| chunk = audio[:,:,-chunk_size:] | |
| chunks.append(chunk) | |
| chunks = torch.stack(chunks) | |
| num_chunks = chunks.shape[0] | |
| # Note: y_size might be a different value from the latent length used in diffusion training | |
| # because we can encode audio of varying lengths | |
| # However, the audio should've been padded to a multiple of samples_per_latent by now. | |
| y_size = total_size // samples_per_latent | |
| # Create an empty latent, we will populate it with chunks as we encode them | |
| y_final = torch.zeros((batch_size,latent_dim,y_size)).to(audio.device) | |
| for i in range(num_chunks): | |
| x_chunk = chunks[i,:] | |
| # encode the chunk | |
| y_chunk = vae_model.encode_export(x_chunk) | |
| # figure out where to put the audio along the time domain | |
| if i == num_chunks-1: | |
| # final chunk always goes at the end | |
| t_end = y_size | |
| t_start = t_end - y_chunk.shape[2] | |
| else: | |
| t_start = i * hop_size // samples_per_latent | |
| t_end = t_start + chunk_size // samples_per_latent | |
| # remove the edges of the overlaps | |
| ol = overlap//samples_per_latent//2 | |
| chunk_start = 0 | |
| chunk_end = y_chunk.shape[2] | |
| if i > 0: | |
| # no overlap for the start of the first chunk | |
| t_start += ol | |
| chunk_start += ol | |
| if i < num_chunks-1: | |
| # no overlap for the end of the last chunk | |
| t_end -= ol | |
| chunk_end -= ol | |
| # paste the chunked audio into our y_final output audio | |
| y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] | |
| return y_final | |
| def prepare_model(device): | |
| # prepare cfm model | |
| dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-1_2", filename="cfm_model.pt") | |
| dit_config_path = "./diffrhythm/config/config.json" | |
| with open(dit_config_path) as f: | |
| model_config = json.load(f) | |
| dit_model_cls = DiT | |
| cfm = CFM( | |
| transformer=dit_model_cls(**model_config["model"], max_frames=2048), | |
| num_channels=model_config["model"]['mel_dim'], | |
| ) | |
| cfm = cfm.to(device) | |
| cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False) | |
| # prepare tokenizer | |
| tokenizer = CNENTokenizer() | |
| # prepare muq | |
| muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained") | |
| muq = muq.to(device).eval() | |
| # prepare vae | |
| vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt") | |
| vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device) | |
| # prepare eval model | |
| train_config = OmegaConf.load("./pretrained/eval.yaml") | |
| checkpoint_path = "./pretrained/eval.safetensors" | |
| eval_model = instantiate(train_config.generator).to(device).eval() | |
| state_dict = load_file(checkpoint_path, device="cpu") | |
| eval_model.load_state_dict(state_dict) | |
| eval_muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter") | |
| eval_muq = eval_muq.to(device).eval() | |
| return cfm, tokenizer, muq, vae, eval_model, eval_muq | |
| # for song edit, will be added in the future | |
| def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_model): | |
| sampling_rate = 44100 | |
| downsample_rate = 2048 | |
| io_channels = 2 | |
| if edit: | |
| input_audio, in_sr = torchaudio.load(ref_song) | |
| input_audio = prepare_audio(input_audio, in_sr=in_sr, target_sr=sampling_rate, target_length=None, target_channels=io_channels, device=device) | |
| input_audio = normalize_audio(input_audio, -6) | |
| with torch.no_grad(): | |
| latent = encode_audio(input_audio, vae_model, chunked=True) # [b d t] | |
| mean, scale = latent.chunk(2, dim=1) | |
| prompt, _ = vae_sample(mean, scale) | |
| prompt = prompt.transpose(1, 2) # [b t d] | |
| pred_segments = json.loads(pred_segments) | |
| # import pdb; pdb.set_trace() | |
| pred_frames = [] | |
| for st, et in pred_segments: | |
| sf = 0 if st == -1 else int(st * sampling_rate / downsample_rate) | |
| # if st == -1: | |
| # sf = 0 | |
| # else: | |
| # sf = int(st * sampling_rate / downsample_rate ) | |
| ef = max_frames if et == -1 else int(et * sampling_rate / downsample_rate) | |
| # if et == -1: | |
| # ef = max_frames | |
| # else: | |
| # ef = int(et * sampling_rate / downsample_rate ) | |
| pred_frames.append((sf, ef)) | |
| # import pdb; pdb.set_trace() | |
| return prompt, pred_frames | |
| else: | |
| prompt = torch.zeros(1, max_frames, 64).to(device) | |
| pred_frames = [(0, max_frames)] | |
| return prompt, pred_frames | |
| def get_negative_style_prompt(device): | |
| file_path = "./src/negative_prompt.npy" | |
| vocal_stlye = np.load(file_path) | |
| vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512] | |
| vocal_stlye = vocal_stlye.half() | |
| return vocal_stlye | |
| def eval_song(eval_model, eval_muq, songs): | |
| resampled_songs = [torchaudio.functional.resample(song.mean(dim=0, keepdim=True), 44100, 24000) for song in songs] | |
| ssl_list = [] | |
| for i in range(len(resampled_songs)): | |
| output = eval_muq(resampled_songs[i], output_hidden_states=True) | |
| muq_ssl = output["hidden_states"][6] | |
| ssl_list.append(muq_ssl.squeeze(0)) | |
| ssl = torch.stack(ssl_list) | |
| scores_g = eval_model(ssl) | |
| score = torch.mean(scores_g, dim=1) | |
| idx = score.argmax(dim=0) | |
| return songs[idx] | |
| def get_audio_style_prompt(model, wav_path): | |
| vocal_flag = False | |
| mulan = model | |
| audio, _ = librosa.load(wav_path, sr=24000) | |
| audio_len = librosa.get_duration(y=audio, sr=24000) | |
| if audio_len <= 1: | |
| vocal_flag = True | |
| if audio_len > 10: | |
| start_time = int(audio_len // 2 - 5) | |
| wav = audio[start_time*24000:(start_time+10)*24000] | |
| else: | |
| wav = audio | |
| wav = torch.tensor(wav).unsqueeze(0).to(model.device) | |
| with torch.no_grad(): | |
| audio_emb = mulan(wavs = wav) # [1, 512] | |
| audio_emb = audio_emb.half() | |
| return audio_emb, vocal_flag | |
| def get_text_style_prompt(model, text_prompt): | |
| mulan = model | |
| with torch.no_grad(): | |
| text_emb = mulan(texts = text_prompt) # [1, 512] | |
| text_emb = text_emb.half() | |
| return text_emb | |
| def get_style_prompt(model, wav_path=None, prompt=None): | |
| mulan = model | |
| if prompt is not None: | |
| return mulan(texts=prompt).half() | |
| ext = os.path.splitext(wav_path)[-1].lower() | |
| if ext == ".mp3": | |
| meta = MP3(wav_path) | |
| audio_len = meta.info.length | |
| elif ext in [".wav", ".flac"]: | |
| audio_len = librosa.get_duration(path=wav_path) | |
| else: | |
| raise ValueError("Unsupported file format: {}".format(ext)) | |
| if audio_len < 10: | |
| print( | |
| f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds." | |
| ) | |
| assert audio_len >= 10 | |
| mid_time = audio_len // 2 | |
| start_time = mid_time - 5 | |
| wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10) | |
| wav = torch.tensor(wav).unsqueeze(0).to(model.device) | |
| with torch.no_grad(): | |
| audio_emb = mulan(wavs=wav) # [1, 512] | |
| audio_emb = audio_emb | |
| audio_emb = audio_emb.half() | |
| return audio_emb | |
| def parse_lyrics(lyrics: str): | |
| lyrics_with_time = [] | |
| lyrics = lyrics.strip() | |
| for line in lyrics.split("\n"): | |
| try: | |
| time, lyric = line[1:9], line[10:] | |
| lyric = lyric.strip() | |
| mins, secs = time.split(":") | |
| secs = int(mins) * 60 + float(secs) | |
| lyrics_with_time.append((secs, lyric)) | |
| except: | |
| continue | |
| return lyrics_with_time | |
| class CNENTokenizer: | |
| def __init__(self): | |
| with open("./diffrhythm/g2p/g2p/vocab.json", "r", encoding='utf-8') as file: | |
| self.phone2id: dict = json.load(file)["vocab"] | |
| self.id2phone = {v: k for (k, v) in self.phone2id.items()} | |
| from diffrhythm.g2p.g2p_generation import chn_eng_g2p | |
| self.tokenizer = chn_eng_g2p | |
| def encode(self, text): | |
| phone, token = self.tokenizer(text) | |
| token = [x + 1 for x in token] | |
| return token | |
| def decode(self, token): | |
| return "|".join([self.id2phone[x - 1] for x in token]) | |
| def get_lrc_token(max_frames, text, tokenizer, device): | |
| lyrics_shift = 0 | |
| sampling_rate = 44100 | |
| downsample_rate = 2048 | |
| max_secs = max_frames / (sampling_rate / downsample_rate) | |
| comma_token_id = 1 | |
| period_token_id = 2 | |
| lrc_with_time = parse_lyrics(text) | |
| modified_lrc_with_time = [] | |
| for i in range(len(lrc_with_time)): | |
| time, line = lrc_with_time[i] | |
| line_token = tokenizer.encode(line) | |
| modified_lrc_with_time.append((time, line_token)) | |
| lrc_with_time = modified_lrc_with_time | |
| lrc_with_time = [ | |
| (time_start, line) | |
| for (time_start, line) in lrc_with_time | |
| if time_start < max_secs | |
| ] | |
| if max_frames == 2048: | |
| lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time | |
| normalized_start_time = 0.0 | |
| lrc = torch.zeros((max_frames,), dtype=torch.long) | |
| tokens_count = 0 | |
| last_end_pos = 0 | |
| for time_start, line in lrc_with_time: | |
| tokens = [ | |
| token if token != period_token_id else comma_token_id for token in line | |
| ] + [period_token_id] | |
| tokens = torch.tensor(tokens, dtype=torch.long) | |
| num_tokens = tokens.shape[0] | |
| gt_frame_start = int(time_start * sampling_rate / downsample_rate) | |
| frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift)) | |
| frame_start = max(gt_frame_start - frame_shift, last_end_pos) | |
| frame_len = min(num_tokens, max_frames - frame_start) | |
| lrc[frame_start : frame_start + frame_len] = tokens[:frame_len] | |
| tokens_count += num_tokens | |
| last_end_pos = frame_start + frame_len | |
| lrc_emb = lrc.unsqueeze(0).to(device) | |
| normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device) | |
| normalized_start_time = normalized_start_time.half() | |
| return lrc_emb, normalized_start_time | |
| def load_checkpoint(model, ckpt_path, device, use_ema=True): | |
| model = model.half() | |
| ckpt_type = ckpt_path.split(".")[-1] | |
| if ckpt_type == "safetensors": | |
| from safetensors.torch import load_file | |
| checkpoint = load_file(ckpt_path) | |
| else: | |
| checkpoint = torch.load(ckpt_path, weights_only=True) | |
| if use_ema: | |
| if ckpt_type == "safetensors": | |
| checkpoint = {"ema_model_state_dict": checkpoint} | |
| checkpoint["model_state_dict"] = { | |
| k.replace("ema_model.", ""): v | |
| for k, v in checkpoint["ema_model_state_dict"].items() | |
| if k not in ["initted", "step"] | |
| } | |
| model.load_state_dict(checkpoint["model_state_dict"], strict=False) | |
| else: | |
| if ckpt_type == "safetensors": | |
| checkpoint = {"model_state_dict": checkpoint} | |
| model.load_state_dict(checkpoint["model_state_dict"], strict=False) | |
| return model.to(device) | |