# Hugging Face Spaces page banner (scrape residue, not program code): "Spaces: Sleeping"
| import os | |
| import gradio as gr | |
| from gradio_components.image import generate_caption | |
| from gradio_components.prediction import predict, transcribe | |
# Per-variable overrides applied on top of the stock Glass theme.
# Grouped by the part of the UI they style; insertion order matches the
# order in which the keyword arguments are handed to ``.set()``.
_THEME_OVERRIDES = {
    # Page, embed and generic backgrounds / borders
    "body_background_fill_dark": "*background_fill_primary",
    "embed_radius": "*table_radius",
    "background_fill_primary": "*neutral_50",
    "background_fill_primary_dark": "*neutral_950",
    "background_fill_secondary_dark": "*neutral_900",
    "border_color_accent": "*neutral_600",
    "border_color_accent_subdued": "*color_accent",
    "border_color_primary_dark": "*neutral_700",
    # Blocks, their labels and titles
    "block_background_fill": "*background_fill_primary",
    "block_background_fill_dark": "*neutral_800",
    "block_border_width": "1px",
    "block_label_background_fill": "*background_fill_primary",
    "block_label_background_fill_dark": "*background_fill_secondary",
    "block_label_text_color": "*neutral_500",
    "block_label_text_size": "*text_sm",
    "block_label_text_weight": "400",
    "block_shadow": "none",
    "block_shadow_dark": "none",
    "block_title_text_color": "*neutral_500",
    "block_title_text_weight": "400",
    # Panels
    "panel_border_width": "0",
    "panel_border_width_dark": "0",
    # Checkboxes
    "checkbox_background_color_dark": "*neutral_800",
    "checkbox_border_width": "*input_border_width",
    "checkbox_label_border_width": "*input_border_width",
    # Text inputs
    "input_background_fill": "*neutral_100",
    "input_background_fill_dark": "*neutral_700",
    "input_border_color_focus_dark": "*neutral_700",
    "input_border_width": "0px",
    "input_border_width_dark": "0px",
    # Sliders
    "slider_color": "#2563eb",
    "slider_color_dark": "#2563eb",
    # Tables
    "table_even_background_fill_dark": "*neutral_950",
    "table_odd_background_fill_dark": "*neutral_900",
    # Buttons (shared, primary, secondary, cancel)
    "button_border_width": "*input_border_width",
    "button_shadow_active": "none",
    "button_primary_background_fill": "*primary_200",
    "button_primary_background_fill_dark": "*primary_700",
    "button_primary_background_fill_hover": "*button_primary_background_fill",
    "button_primary_background_fill_hover_dark": "*button_primary_background_fill",
    "button_secondary_background_fill": "*neutral_200",
    "button_secondary_background_fill_dark": "*neutral_600",
    "button_secondary_background_fill_hover": "*button_secondary_background_fill",
    "button_secondary_background_fill_hover_dark": "*button_secondary_background_fill",
    "button_cancel_background_fill": "*button_secondary_background_fill",
    "button_cancel_background_fill_dark": "*button_secondary_background_fill",
    "button_cancel_background_fill_hover": "*button_cancel_background_fill",
    "button_cancel_background_fill_hover_dark": "*button_cancel_background_fill",
}

# Glass theme with the project's hues and font stack, then the overrides above.
theme = gr.themes.Glass(
    primary_hue="fuchsia",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=[
        gr.themes.GoogleFont("Source Sans Pro"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
).set(**_THEME_OVERRIDES)
| _AUDIOCRAFT_MODELS = [ | |
| "facebook/musicgen-melody", | |
| "facebook/musicgen-medium", | |
| "facebook/musicgen-small", | |
| "facebook/musicgen-large", | |
| "facebook/musicgen-melody-large", | |
| "facebook/audiogen-medium", | |
| ] | |
def generate_prompt(difficulty, style):
    """Build the text prompt describing piano practice music.

    Args:
        difficulty: Skill-level label: "Easy", "Medium" or "Hard".
        style: Music genre to mention in the prompt; inserted verbatim.

    Returns:
        The prompt string with the player description for ``difficulty``
        and the given ``style`` filled in.

    Raises:
        KeyError: If ``difficulty`` is not one of the known labels.
    """
    medium_player = "player who has 2-3 years experience"
    difficulty_mapping = {
        "Easy": "beginner player",
        # The key used to be misspelled "Medum", so the "Medium" label
        # raised KeyError.
        "Medium": medium_player,
        # Legacy misspelling kept so existing callers keep working.
        "Medum": medium_player,
        "Hard": "player who has more than 4 years experiences",
    }
    player = difficulty_mapping[difficulty]
    # NOTE(review): "pratice" is misspelled, but the wording is left
    # byte-identical so prompts for existing labels do not change.
    return "piano only music for a {} to pratice with the touch of {}".format(
        player, style
    )
def toggle_melody_condition(melody_condition):
    """Return the melody audio input, shown or hidden per the checkbox.

    Args:
        melody_condition: Checkbox state; any truthy value shows the input.

    Returns:
        A ``gr.Audio`` (microphone/upload sources) whose ``visible`` flag is
        True when ``melody_condition`` is truthy and False otherwise.
    """
    # Only the visibility differs between the two states.
    return gr.Audio(
        sources=["microphone", "upload"],
        label="Record or upload your audio",
        show_label=True,
        visible=bool(melody_condition),
    )
def show_caption(show_caption_condition, description, prompt):
    """Show or hide the caption and prompt boxes, keeping their contents.

    Args:
        show_caption_condition: Checkbox state; truthy shows both textboxes.
        description: Current image-caption text, passed through unchanged.
        prompt: Current generated-prompt text, passed through unchanged.

    Returns:
        A 3-tuple ``(caption_box, prompt_box, generate_button)``: the
        read-only caption textbox, the editable prompt textbox (both visible
        only when ``show_caption_condition`` is truthy), and the
        "Generate Music" button, which stays visible in either case.
    """
    visible = bool(show_caption_condition)
    caption_box = gr.Textbox(
        label="Image Caption",
        value=description,
        interactive=False,
        show_label=True,
        visible=visible,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt,
        interactive=True,
        show_label=True,
        visible=visible,
    )
    # The hide branch used to call gr.Button(label="Generate Music", ...).
    # A Button's text is its ``value`` (first positional argument), so both
    # states now build the button the same way.
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
def post_submit(show_caption, model_path, image_input):
    """Caption the submitted image and fill the caption/prompt boxes.

    Calls ``generate_caption(image_input, model_path)``, discards the first
    element of its 3-item result, and uses the other two as the caption and
    the prompt.

    Args:
        show_caption: Passed straight through as the ``visible`` flag of both
            textboxes.
        model_path: Forwarded to ``generate_caption``.
        image_input: Forwarded to ``generate_caption``.

    Returns:
        A 3-tuple ``(caption_box, prompt_box, generate_button)``; the button
        is always visible and interactive.
    """
    _unused, caption_text, prompt_text = generate_caption(image_input, model_path)

    caption_box = gr.Textbox(
        label="Image Caption",
        value=caption_text,
        interactive=False,
        show_label=True,
        visible=show_caption,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt_text,
        interactive=True,
        show_label=True,
        visible=show_caption,
    )
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
# Example rows for the melody tab: (audio file, difficulty, sample rate, duration).
_MELODY_EXAMPLES = (
    ("./data/audio/twinkle_twinkle_little_stars_mozart_20sec.mp3", "Easy", 32000, 20),
    ("./data/audio/golden_hour_20sec.mp3", "Easy", 32000, 20),
    ("./data/audio/turkish_march_mozart_20sec.mp3", "Easy", 32000, 20),
    ("./data/audio/golden_hour_20sec.mp3", "Hard", 32000, 20),
    ("./data/audio/golden_hour_20sec.mp3", "Hard", 32000, 40),
    ("./data/audio/golden_hour_20sec.mp3", "Hard", 16000, 20),
)

# Example rows for the image tab: (image file, use melody?, melody file, model).
_IMAGE_EXAMPLES = (
    ("./data/image/kids_drawing.jpeg", False, None, "facebook/musicgen-large"),
    ("./data/image/cat.jpeg", False, None, "facebook/musicgen-large"),
    (
        "./data/image/cat.jpeg",
        True,
        "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3",
        "facebook/musicgen-melody-large",
    ),
    ("./data/image/beach.jpeg", False, None, "facebook/audiogen-medium"),
)


def _example_asset(relative_path):
    """Return ``relative_path`` resolved against this file's directory."""
    return os.path.join(os.path.dirname(__file__), relative_path)


def _toggle_custom_genre(style):
    """Show the free-text genre box only while the "Others" genre is selected."""
    return gr.Textbox(visible=style == "Others")


def _refresh_prompt(difficulty, style, custom_genre, customize, current_prompt):
    """Recompute the melody-tab prompt box after one of its controls changed.

    While "Customize the prompt" is ticked the user's text is kept and the
    box is editable; otherwise the box is read-only and shows the prompt
    generated from the difficulty and genre.
    """
    if customize:
        return gr.Textbox(value=current_prompt, interactive=True)
    genre = custom_genre if style == "Others" else style
    try:
        generated = generate_prompt(difficulty, genre)
    except KeyError as err:
        # generate_prompt raises KeyError for a difficulty it has no entry
        # for; report that in the UI instead of failing silently.
        raise gr.Error(f"No prompt template for difficulty {difficulty!r}") from err
    return gr.Textbox(value=generated, interactive=False)


def _build_melody_tab():
    """Create the "Generate Music by melody" tab inside the active Blocks."""
    with gr.Tab("Generate Music by melody"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    model_path = gr.Dropdown(
                        choices=_AUDIOCRAFT_MODELS,
                        label="Select the model",
                        value="facebook/musicgen-melody-large",
                    )
                with gr.Row():
                    duration = gr.Slider(
                        minimum=10,
                        maximum=60,
                        value=10,
                        label="Duration",
                        interactive=True,
                    )
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(
                        label="Temperature", value=1.0, interactive=True
                    )
                    sample_rate = gr.Number(
                        label="output music sample rate",
                        value=32000,
                        interactive=True,
                    )
                difficulty = gr.Radio(
                    ["Easy", "Medium", "Hard"],
                    label="Difficulty",
                    value="Easy",
                    interactive=True,
                )
                style = gr.Radio(
                    ["Jazz", "Classical Music", "Hip Hop", "Others"],
                    value="Classical Music",
                    label="music genre",
                    interactive=True,
                )
                # The old code tested `style == "Others"` and `if customize:`
                # while building the layout. Those compare component objects,
                # not user choices, so the genre box never appeared and the
                # prompt was always a blank textbox. The controls are now
                # always created and driven by the change listeners below.
                custom_genre = gr.Textbox(label="Type your music genre", visible=False)
                customize = gr.Checkbox(label="Customize the prompt", interactive=True)
                prompt = gr.Textbox(
                    label="Type your prompt",
                    value=generate_prompt(difficulty.value, style.value),
                    interactive=False,
                )
            with gr.Column():
                with gr.Row():
                    melody = gr.Audio(
                        sources=["microphone", "upload"],
                        label="Record or upload your audio",
                        show_label=True,
                    )
                with gr.Row():
                    submit = gr.Button("Generate Music")
                    # The text is the label; passed positionally it would be
                    # taken as the component's initial value (a file path).
                    output_audio = gr.Audio(
                        label="listen to the generated music", type="filepath"
                    )
                with gr.Row():
                    transcribe_button = gr.Button("Transcribe")
                    download_button = gr.DownloadButton(
                        "Download the file", visible=False
                    )

            style.change(fn=_toggle_custom_genre, inputs=[style], outputs=custom_genre)
            prompt_inputs = [difficulty, style, custom_genre, customize, prompt]
            for control in (difficulty, style, custom_genre, customize):
                control.change(fn=_refresh_prompt, inputs=prompt_inputs, outputs=prompt)

            transcribe_button.click(
                transcribe, inputs=[output_audio], outputs=download_button
            )
            submit.click(
                fn=predict,
                inputs=[
                    model_path,
                    prompt,
                    melody,
                    duration,
                    topk,
                    topp,
                    temperature,
                    sample_rate,
                ],
                outputs=output_audio,
            )
            gr.Examples(
                examples=[
                    [_example_asset(audio), level, rate, seconds]
                    for audio, level, rate, seconds in _MELODY_EXAMPLES
                ],
                inputs=[melody, difficulty, sample_rate, duration],
                label="Audio Examples",
                outputs=[output_audio],
            )


def _build_image_tab():
    """Create the "Generate Music by image" tab inside the active Blocks."""
    with gr.Tab("Generate Music by image"):
        with gr.Row():
            with gr.Column():
                # Label given by keyword; positionally it would be the value.
                image_input = gr.Image(label="Upload an image", type="filepath")
                melody_condition = gr.Checkbox(
                    label="Generate music by melody", interactive=True, value=False
                )
                melody = gr.Audio(
                    sources=["microphone", "upload"],
                    label="Record or upload your audio",
                    show_label=True,
                    visible=False,
                )
                melody_condition.change(
                    fn=toggle_melody_condition,
                    inputs=[melody_condition],
                    outputs=melody,
                )
                description = gr.Textbox(
                    label="Image Captioning",
                    show_label=True,
                    interactive=False,
                    visible=False,
                )
                prompt = gr.Textbox(
                    label="Generated Prompt",
                    show_label=True,
                    interactive=True,
                    visible=False,
                )
                show_prompt = gr.Checkbox(label="Show the prompt", interactive=True)
                submit = gr.Button("submit", interactive=True, visible=True)
                generate = gr.Button(
                    "Generate Music", interactive=True, visible=False
                )
            with gr.Column():
                with gr.Row():
                    model_path = gr.Dropdown(
                        choices=_AUDIOCRAFT_MODELS,
                        label="Select the model",
                        value="facebook/musicgen-large",
                    )
                with gr.Row():
                    duration = gr.Slider(
                        minimum=10,
                        maximum=60,
                        value=10,
                        label="Duration",
                        interactive=True,
                    )
                topk = gr.Number(label="Top-k", value=250, interactive=True)
                topp = gr.Number(label="Top-p", value=0, interactive=True)
                temperature = gr.Number(
                    label="Temperature", value=1.0, interactive=True
                )
                sample_rate = gr.Number(
                    label="output music sample rate", value=32000, interactive=True
                )
            with gr.Column():
                output_audio = gr.Audio(
                    label="listen to the generated music",
                    type="filepath",
                    show_label=True,
                )
                transcribe_button = gr.Button("Transcribe")
                download_button = gr.DownloadButton(
                    "Download the file", visible=False
                )

            submit.click(
                fn=post_submit,
                inputs=[show_prompt, model_path, image_input],
                outputs=[description, prompt, generate],
            )
            show_prompt.change(
                fn=show_caption,
                inputs=[show_prompt, description, prompt],
                outputs=[description, prompt, generate],
            )
            transcribe_button.click(
                transcribe, inputs=[output_audio], outputs=download_button
            )
            generate.click(
                fn=predict,
                inputs=[
                    model_path,
                    prompt,
                    melody,
                    duration,
                    topk,
                    topp,
                    temperature,
                    sample_rate,
                ],
                outputs=output_audio,
            )
            gr.Examples(
                examples=[
                    [
                        _example_asset(image),
                        use_melody,
                        # The melody example used to be a bare relative path,
                        # unlike every other example file; resolve it the
                        # same way so it does not depend on the working dir.
                        _example_asset(audio) if audio is not None else None,
                        model,
                    ]
                    for image, use_melody, audio, model in _IMAGE_EXAMPLES
                ],
                inputs=[image_input, melody_condition, melody, model_path],
                label="Audio Examples",
                outputs=[output_audio],
            )


def UI():
    """Build the two-tab demo (music from a melody / from an image) and launch it.

    Takes no arguments and returns nothing; it enables the queue and calls
    ``launch()`` on the assembled Blocks app.
    """
    # The module-level `theme` was built but never handed to Blocks.
    with gr.Blocks(theme=theme) as demo:
        _build_melody_tab()
        _build_image_tab()
    demo.queue().launch()
# Build and launch the demo only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    UI()