import gradio as gr
import json
import torch
import torchaudio
import os
import random
import numpy as np
import io
import pydub
import base64
from muq import MuQMuLan
from diffrhythm2.cfm import CFM
from diffrhythm2.backbones.dit import DiT
from bigvgan.model import Generator
from huggingface_hub import hf_hub_download

# Token ids for song-structure markers appended to the lyric token stream.
STRUCT_INFO = {
    "[start]": 500,
    "[end]": 501,
    "[intro]": 502,
    "[verse]": 503,
    "[chorus]": 504,
    "[outro]": 505,
    "[inst]": 506,
    "[solo]": 507,
    "[bridge]": 508,
    "[hook]": 509,
    "[break]": 510,
    "[stop]": 511,
    "[space]": 512,
}


class CNENTokenizer:
    """Chinese/English grapheme-to-phoneme tokenizer backed by the bundled g2p vocab."""

    def __init__(self):
        curr_path = os.path.abspath(__file__)
        vocab_path = os.path.join(os.path.dirname(curr_path), "g2p/g2p/vocab.json")
        with open(vocab_path, 'r') as file:
            self.phone2id: dict = json.load(file)['vocab']
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        # Shift ids by 1 so that 0 stays free (e.g. for padding).
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def prepare_model(repo_id, device, dtype):
    # Download the DiffRhythm2 checkpoint and config from the Hub.
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)
    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(**model_config),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )
    total_params = sum(p.numel() for p in diffrhythm2.parameters())
    diffrhythm2 = diffrhythm2.to(device).to(dtype)
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # Load MuLan, used to embed the style prompt (text or reference audio).
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)

    # Load the lyric frontend.
    lrc_tokenizer = CNENTokenizer()

    # Load the vocoder decoder (latent -> waveform).
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device).to(dtype)

    return diffrhythm2, mulan, lrc_tokenizer, decoder


def parse_lyrics(lyrics: str):
    """Tokenize lyrics line by line: structure tags map to their ids, text lines to phoneme tokens."""
    lyrics_with_time = []
    lyrics = lyrics.split("\n")
    for line in lyrics:
        struct_idx = STRUCT_INFO.get(line, None)
        if struct_idx is not None:
            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
        else:
            tokens = lrc_tokenizer.encode(line.strip())
            tokens = tokens + [STRUCT_INFO['[stop]']]
            lyrics_with_time.append(tokens)
    return lyrics_with_time


def get_audio_prompt(model, audio_file, device, dtype):
    prompt_wav, sr = torchaudio.load(audio_file)
    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
    # Clips longer than 10 s are randomly cropped to a 10 s window.
    if prompt_wav.shape[1] > 24000 * 10:
        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
        prompt_wav = prompt_wav[:, start:start + 24000 * 10]
    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)  # downmix to mono
    with torch.no_grad():
        style_prompt_embed = model(wavs=prompt_wav)
    return style_prompt_embed.squeeze(0)
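
# Usage sketch (comments only, not executed at import): how the prompt helpers
# above and below turn user inputs into model-ready tensors. "ref.wav" is a
# hypothetical path; mulan/device/dtype are the globals set up further down.
#
#   style = get_audio_prompt(mulan, "ref.wav", device, dtype)           # MuLan embedding of a reference clip
#   style = get_text_prompt(mulan, "emotional piano pop", device, dtype)
#   ids = sum(parse_lyrics("[verse]\nHello world"), [])                 # flat token id list; every line ends with the [stop] id
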
def get_text_prompt(model, text, device, dtype):
    with torch.no_grad():
        style_prompt_embed = model(texts=[text])
    return style_prompt_embed.squeeze(0)


def make_fake_stereo(audio, sampling_rate):
    """Fake a stereo image from mono: attenuate and slightly delay the right channel."""
    left_channel = audio
    right_channel = audio.clone()
    right_channel = right_channel * 0.8
    delay_samples = int(0.01 * sampling_rate)  # 10 ms delay
    right_channel = torch.roll(right_channel, delay_samples)
    right_channel[:, :delay_samples] = 0
    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
    return stereo_audio


def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            duration=int(duration * 5),  # seconds -> latent frames (5 per second)
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            odeint_method=odeint_method,
        )
        latent = latent.transpose(1, 2)
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20)

    audio = audio.float().cpu().squeeze()[None, :]
    if fake_stereo:
        audio = make_fake_stereo(audio, decoder.h.sampling_rate)

    if file_type == 'wav':
        return (decoder.h.sampling_rate, audio.numpy().T)  # (time, channels), as gr.Audio expects
    else:
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
        return buffer.getvalue()


def inference_stream(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        for audio in model.sample_cache_stream(
            decoder=decoder,
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            chunk_size=20,
            overlap=5,
            odeint_method=odeint_method,
        ):
            # Keep the chunk as a tensor until after make_fake_stereo, which
            # uses torch ops (clone/roll/cat); convert to numpy only when yielding.
            audio = audio.float().cpu().squeeze()[None, :]
            if fake_stereo:
                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            yield (decoder.h.sampling_rate, audio.numpy().T)  # (time, channels)


lrc_tokenizer = None
MAX_SEED = np.iinfo(np.int32).max
device = 'cuda'
dtype = torch.float16
diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)

import spaces


@spaces.GPU
def infer_music(
    lrc,
    current_prompt_type,
    audio_prompt=None,
    text_prompt=None,
    seed=42,
    randomize_seed=False,
    steps=16,
    cfg_strength=1.0,
    file_type='wav',
    odeint_method='euler',
    device='cuda',
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    torch.manual_seed(seed)
    print(seed, current_prompt_type)
    try:
        lrc_prompt = parse_lyrics(lrc)
        # Flatten the per-line token lists into a single 1-D token sequence.
        lrc_prompt = torch.tensor(sum(lrc_prompt, []), dtype=torch.long, device=device)
        if current_prompt_type == "audio":
            style_prompt = get_audio_prompt(mulan, audio_prompt, device, dtype)
        else:
            style_prompt = get_text_prompt(mulan, text_prompt, device, dtype)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")
    style_prompt = style_prompt.to(dtype)

    generate_song = inference(
        model=diffrhythm2,
        decoder=decoder,
        text=lrc_prompt,
        style_prompt=style_prompt,
        sample_steps=steps,
        cfg_strength=cfg_strength,
        odeint_method=odeint_method,
        duration=240,
        file_type=file_type,
    )
    return generate_song
    # Streaming variant (kept for reference):
    # for block in inference_stream(
    #     model=diffrhythm2,
    #     decoder=decoder,
    #     text=lrc_prompt,
    #     style_prompt=style_prompt,
    #     sample_steps=steps,
    #     cfg_strength=cfg_strength,
    #     odeint_method=odeint_method,
    #     duration=240,
    #     file_type=file_type,
    # ):
    #     yield block
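
# Streaming sketch (assumption, not wired up): to stream chunks to the browser
# instead of returning one finished clip, infer_music would yield from
# inference_stream(...) (see the commented block above) and the output widget
# defined below would need streaming enabled, e.g.
#   audio_output = gr.Audio(label="Audio Result", streaming=True)
# Gradio then plays each yielded (sampling_rate, ndarray) chunk as it arrives.
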
css = """
/* Pin the lyrics textarea height and force a scrollbar */
.lyrics-scroll-box textarea {
    height: 405px !important;       /* fixed height */
    max-height: 500px !important;   /* maximum height */
    overflow-y: auto !important;    /* vertical scrolling */
    white-space: pre-wrap;          /* preserve line breaks */
    line-height: 1.5;               /* comfortable line height */
}
.gr-examples {
    background: transparent !important;
    border: 1px solid #e0e0e0 !important;
    border-radius: 8px;
    margin: 1rem 0 !important;
    padding: 1rem !important;
}
"""


def image_to_base64(path):
    # Used to inline images into the header HTML below, e.g. (hypothetical path):
    #   f'<img src="data:image/png;base64,{image_to_base64("src/logo.png")}"/>'
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


with gr.Blocks(css=css) as demo:
    gr.HTML(f"""
    """)
    with gr.Tabs() as tabs:
        # page 1
        with gr.Tab("Music Generate", id=0):
            with gr.Row():
                with gr.Column():
                    lrc = gr.Textbox(
                        label="Lyrics",
                        placeholder="Input the full lyrics",
                        lines=12,
                        max_lines=50,
                        elem_classes="lyrics-scroll-box",
                        value="""[start]
[intro]
[verse]
Thought I heard your voice yesterday
When I turned around to say
That I loved you baby
I realize it was just my mind
Played tricks on me
And it seems colder lately at night
And I try to sleep with the lights on
Every time the phone rings
I pray to God it's you
And I just can't believe
That we're through
[chorus]
I miss you
There's no other way to say it
And I can't deny it
I miss you
It's so easy to see
I miss you and me
[verse]
Is it turning over this time
Have we really changed our minds about each other's love
All the feelings that we used to share
I refuse to believe
That you don't care
[chorus]
I miss you
There's no other way to say it
And I and I can't deny it
I miss you
[verse]
It's so easy to see
I've got to gather myself together
I've been through worst kinds of weather
If it's over now
[outro]""",
                    )
                    current_prompt_type = gr.State(value="text")
                    with gr.Tabs() as inside_tabs:
                        with gr.Tab("Text Prompt"):
                            text_prompt = gr.Textbox(
                                label="Text Prompt",
                                value="Pop, Piano, Bass, Drums, Happy",
                                placeholder="Enter the text prompt, e.g. emotional piano pop",
                            )
                        with gr.Tab("Audio Prompt"):
                            audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")

                    # The prompt type lives in gr.State rather than a visible widget;
                    # selecting a tab rewrites it so infer_music knows which MuLan
                    # path (text or audio) to embed. Tab 0 is text, tab 1 is audio.
                    def update_prompt_type(evt: gr.SelectData):
                        return "text" if evt.index == 0 else "audio"

                    inside_tabs.select(
                        fn=update_prompt_type,
                        outputs=current_prompt_type,
                    )
                with gr.Column():
                    with gr.Accordion("Best Practices Guide", open=True):
                        gr.Markdown("""
1. **Lyrics Format Requirements**
   - Each line is either a structure tag (e.g. `[verse]`) or plain lyric content
   - Example of a valid format:
   ```
   [intro]
   [verse]
   Thought I heard your voice yesterday
   When I turned around to say
   ```
2. **Audio Prompt Requirements**
   - Reference audio should be ≥ 1 second; audio longer than 10 seconds is randomly clipped to a 10-second segment
   - For optimal results, select that 10-second clip carefully
   - Shorter clips may lead to incoherent generation
3. **Supported Languages**
   - Chinese and English
""")
                    lyrics_btn = gr.Button("Generate", variant="primary")
                    audio_output = gr.Audio(label="Audio Result", elem_id="audio_output")
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                        steps = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=16,
                            step=1,
                            label="Diffusion Steps",
                            interactive=True,
                            elem_id="step_slider",
                        )
                        cfg_strength = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=1.0,
                            step=0.5,
                            label="CFG Strength",
                            interactive=True,
                            elem_id="cfg_slider",
                        )
                        odeint_method = gr.Radio(
                            ["euler", "midpoint", "rk4", "implicit_adams"],
                            label="ODE Solver",
                            value="euler",
                        )
                        file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
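
            # The gr.Examples galleries below are kept commented out for reference:
            # reference-audio style prompts (src/prompt/*.wav), text style prompts,
            # and timestamped lyric examples.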
            # gr.Examples(
            #     examples=[
            #         ["src/prompt/classic_cn.wav"],
            #         ["src/prompt/classic_en.wav"],
            #         ["src/prompt/country_cn.wav"],
            #         ["src/prompt/country_en.wav"],
            #         ["src/prompt/jazz_cn.wav"],
            #         ["src/prompt/jazz_en.wav"],
            #         ["src/prompt/pop_cn.wav"],
            #         ["src/prompt/pop_en.wav"],
            #         ["src/prompt/rap_cn.wav"],
            #         ["src/prompt/rap_en.wav"],
            #         ["src/prompt/rock_cn.wav"],
            #         ["src/prompt/rock_en.wav"],
            #     ],
            #     inputs=[audio_prompt],
            #     label="Audio Examples",
            #     examples_per_page=12,
            #     elem_id="audio-examples-container",
            # )
            # gr.Examples(
            #     examples=[
            #         ["Pop Emotional Piano"],
            #         ["流行 情感 钢琴"],
            #         ["Indie folk ballad, coming-of-age themes, acoustic guitar picking with harmonica interludes"],
            #         ["独立民谣, 成长主题, 原声吉他弹奏与口琴间奏"],
            #     ],
            #     inputs=[text_prompt],
            #     label="Text Examples",
            #     examples_per_page=4,
            #     elem_id="text-examples-container",
            # )
            # gr.Examples(
            #     examples=[
            #         ["""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
            #         ["""[00:05.00]Stardust whispers in your eyes\n[00:09.30]Moonlight paints our silhouettes\n[00:13.75]Tides bring secrets from the deep\n[00:18.20]Where forever's breath is kept\n[00:22.90]We dance through constellations' maze\n[00:27.15]Footprints melt in cosmic waves\n[00:31.65]Horizons hum our silent vow\n[00:36.10]Time unravels here and now\n[00:40.85]Eternal embers in the night oh oh oh\n[00:45.25]Healing scars with liquid light\n[00:49.70]Galaxies write our refrain\n[00:54.15]Love reborn in endless rain\n[01:15.30]Paper boats of memories\n[01:19.75]Float through veins of ancient trees\n[01:24.20]Your laughter spins aurora threads\n[01:28.65]Weaving dawn through featherbed"""],
            #         ["""[00:04.27]只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n[01:36.35]你到底属于谁\n[01:37.65]就是现在告诉我\n[01:40.00]跟着那节奏 缓缓 make wave\n"""],
            #         ["""[00:16.55]倦鸟西归 竹影余晖\n[00:23.58]禅意心扉\n[00:27.32]待清风 拂开一池春水\n[00:30.83]你的手绘 玉色难褪\n[00:37.99]我端详飘散的韵味\n[00:40.65]落款壶底的名讳\n[00:42.92]如吻西施的嘴\n[00:45.14]风雅几回 总相随\n[00:52.32]皆因你珍贵\n[00:57.85]三千弱水 煮一杯\n[01:02.21]我只饮下你的美\n[01:04.92]千年余味 紫砂壶伴我醉\n[01:09.73]酿一世无悔\n[01:12.09]沏壶春水 翠烟飞\n[01:16.62]把盏不尽你的香味\n[01:20.06]邀月相对 愿今生同宿同归\n[01:26.43]只让你陪\n[01:46.12]茗香芳菲 世俗无追\n"""],
            #     ],
            #     inputs=[lrc],
            #     label="Lrc Examples",
            #     examples_per_page=4,
            #     elem_id="lrc-examples-container",
            # )

    tabs.select(
        lambda s: None,
        None,
        None,
    )

    # TODO: add a max_frames parameter for infer_music
    lyrics_btn.click(
        fn=infer_music,
        inputs=[
            lrc,
            current_prompt_type,
            audio_prompt,
            text_prompt,
            seed,
            randomize_seed,
            steps,
            cfg_strength,
            file_type,
            odeint_method,
        ],
        outputs=audio_output,
    )

# demo.queue().launch(show_api=False, show_error=True)
if __name__ == "__main__":
    demo.launch()
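
# Client-side sketch (assumption, comments only): querying the deployed app with
# gradio_client. "user/diffrhythm2-space" is a hypothetical Space id; the
# positional inputs mirror the lyrics_btn.click(...) wiring above.
#
#   from gradio_client import Client
#   client = Client("user/diffrhythm2-space")
#   result = client.predict(
#       "[verse]\nThought I heard your voice yesterday",  # lrc
#       "text",                 # current_prompt_type
#       None,                   # audio_prompt
#       "emotional piano pop",  # text_prompt
#       0, True, 16, 1.0,       # seed, randomize_seed, steps, cfg_strength
#       "mp3", "euler",         # file_type, odeint_method
#   )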