import os
import io
import re
import json
import random

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from muq import MuQMuLan

from diffrhythm2.cfm import CFM
from diffrhythm2.backbones.dit import DiT
from bigvgan.model import Generator

# Token ids of the song-structure tags recognized in lyrics.
STRUCT_INFO = {
    "[start]": 500,
    "[end]": 501,
    "[intro]": 502,
    "[verse]": 503,
    "[chorus]": 504,
    "[outro]": 505,
    "[inst]": 506,
    "[solo]": 507,
    "[bridge]": 508,
    "[hook]": 509,
    "[break]": 510,
    "[stop]": 511,
    "[space]": 512,
}


class CNENTokenizer():
    """Grapheme-to-phoneme tokenizer for Chinese/English lyrics."""

    def __init__(self):
        curr_path = os.path.abspath(__file__)
        vocab_path = os.path.join(os.path.dirname(os.path.dirname(curr_path)), "g2p/g2p/vocab.json")
        with open(vocab_path, 'r') as file:
            self.phone2id: dict = json.load(file)['vocab']
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        # Shift ids by one so that 0 stays free for padding.
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def prepare_model(repo_id, device, dtype):
    # Download the DiffRhythm2 checkpoint and config.
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)
    model_config['use_flex_attn'] = False

    diffrhythm2 = CFM(
        transformer=DiT(**model_config),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )
    total_params = sum(p.numel() for p in diffrhythm2.parameters())
    diffrhythm2 = diffrhythm2.to(device).to(dtype)

    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # Load the MuQ-MuLan style encoder used for audio/text style prompts.
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)

    # Load the lyrics frontend.
    lrc_tokenizer = CNENTokenizer()

    # Load the vocoder that decodes latents into waveforms.
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device).to(dtype)

    return diffrhythm2, mulan, lrc_tokenizer, decoder


STRUCT_PATTERN = re.compile(r'^\[.*?\]$')


def parse_lyrics(lrc_tokenizer, lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.split("\n")
    get_start = False
    for line in lyrics:
        line = line.strip()
        if not line:
            continue
        struct_flag = STRUCT_PATTERN.match(line)
        if struct_flag:
            struct_idx = STRUCT_INFO.get(line.lower(), None)
            if struct_idx is not None:
                if struct_idx == STRUCT_INFO['[start]']:
                    get_start = True
                lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
            else:
                # Skip bracketed tags that are not in STRUCT_INFO.
                continue
        else:
            tokens = lrc_tokenizer.encode(line.strip())
            tokens = tokens + [STRUCT_INFO['[stop]']]
            lyrics_with_time.append(tokens)
    # Prepend a [start] tag if the lyrics did not provide one.
    if len(lyrics_with_time) != 0 and not get_start:
        lyrics_with_time = [[STRUCT_INFO['[start]'], STRUCT_INFO['[stop]']]] + lyrics_with_time
    return lyrics_with_time


@torch.no_grad()
def get_audio_prompt(model, audio_file, device, dtype):
    prompt_wav, sr = torchaudio.load(audio_file)
    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
    # Use a random 10-second window when the reference audio is longer.
    if prompt_wav.shape[1] > 24000 * 10:
        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
        prompt_wav = prompt_wav[:, start:start + 24000 * 10]
    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
    style_prompt_embed = model(wavs=prompt_wav)
    return style_prompt_embed.squeeze(0).detach()


@torch.no_grad()
def get_text_prompt(model, text, device, dtype):
    style_prompt_embed = model(texts=[text])
    return style_prompt_embed.squeeze(0).detach()


@torch.no_grad()
def make_fake_stereo(audio, sampling_rate):
    # Build a pseudo-stereo signal: the right channel is an attenuated
    # copy of the left, delayed by 10 ms.
    left_channel = audio
    right_channel = audio.clone()
    right_channel = right_channel * 0.8
    delay_samples = int(0.01 * sampling_rate)
    right_channel = torch.roll(right_channel, delay_samples)
    right_channel[:, :delay_samples] = 0
    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
    return stereo_audio


def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        # Duration is given in seconds; the model generates 5 latent frames per second.
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            odeint_method=odeint_method,
        )
        latent = latent.transpose(1, 2).detach()
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20).detach()

        num_channels = 1
        audio = audio.float().cpu().detach().squeeze()[None, :]  # [channel, time]
        if fake_stereo:
            audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            num_channels = 2

        if file_type == 'wav':
            return (decoder.h.sampling_rate, audio.numpy().T)  # [time, channel]
        else:
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
            return buffer.getvalue()


def inference_stream(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        for audio in model.sample_cache_stream(
            decoder=decoder,
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            chunk_size=20,
            overlap=5,
            odeint_method=odeint_method,
        ):
            # Keep the chunk as a tensor until pseudo-stereo is applied,
            # then yield numpy data shaped [time, channel].
            audio = audio.float().cpu().squeeze()[None, :]
            if fake_stereo:
                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            yield (decoder.h.sampling_rate, audio.numpy().T)
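

# --- Usage sketch (illustrative, not part of the original script) ---
# A minimal example of wiring the helpers above together. Assumptions are
# flagged inline: the repo id is a placeholder, the per-line token lists
# returned by parse_lyrics() are simply flattened into one LongTensor, and
# the duration and text style prompt are arbitrary choices.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Placeholder repo id; substitute the actual DiffRhythm2 model repo.
    model, mulan, lrc_tokenizer, decoder = prepare_model("<repo_id>", device, dtype)

    lyrics = "[start]\n[verse]\nhello world\n[end]"
    token_lists = parse_lyrics(lrc_tokenizer, lyrics)
    # Assumption: concatenate the per-line token lists into a single sequence.
    text = torch.tensor(
        [tok for line in token_lists for tok in line],
        dtype=torch.long,
        device=device,
    )

    # Style prompt from a text description (get_audio_prompt works analogously).
    style_prompt = get_text_prompt(mulan, "upbeat pop with piano", device, dtype)

    sample_rate, audio = inference(model, decoder, text, style_prompt, duration=95)
    # inference() returns [time, channel]; torchaudio.save expects [channel, time].
    torchaudio.save("output.wav", torch.from_numpy(audio.T.copy()), sample_rate)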