# Standard library
import ast
import json
import time
import typing as tp
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

# Third-party
import basic_pitch
import basic_pitch.inference
import gradio as gr
import torch
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import AudioGen, MAGNeT, MusicGen
from basic_pitch import ICASSP_2022_MODEL_PATH
# from transformers import AutoModelForSeq2SeqLM

MODEL = None  # cached model instance; reloaded only when the requested version changes


def load_model(version='facebook/musicgen-large'):
    global MODEL
    if MODEL is None or MODEL.name != version:
        # Free the previous model before loading a new one.
        del MODEL
        MODEL = None  # in case loading would crash
        print("Loading model", version)
        if "magnet" in version:
            MODEL = MAGNeT.get_pretrained(version)
        elif "musicgen" in version:
            MODEL = MusicGen.get_pretrained(version)
        elif "musiclang" in version:
            # TODO: Implement MusicLang
            pass
        elif "audiogen" in version:
            MODEL = AudioGen.get_pretrained(version)
        else:
            raise ValueError("Invalid model version")
    return MODEL

pool = ProcessPoolExecutor(4)


class FileCleaner:
    """Delete registered temporary files once they outlive `file_lifetime` seconds."""

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        # Entries are appended in chronological order, so stop at the first
        # file that is still young enough to keep.
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break


file_cleaner = FileCleaner()
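# Usage pattern (see _do_predictions and transcribe below): register every
# temporary .wav/.mid right after writing it; expired files are unlinked
# lazily on the next add() call rather than by a background thread.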

def inference_musicgen_text_to_music(model, configs, text, num_outputs=1):
    model.set_generation_params(**configs)
    descriptions = [text for _ in range(num_outputs)]
    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
    return output

def inference_musicgen_continuation(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    model.set_generation_params(**configs)
    # Batch the prompt so several continuations come back at once, and keep
    # the generation text-conditioned when a description is provided.
    if prompt_waveform.dim() == 2:
        prompt_waveform = prompt_waveform[None]
    prompt = prompt_waveform.repeat(num_outputs, 1, 1)
    descriptions = [text for _ in range(num_outputs)] if text else None
    output = model.generate_continuation(
        prompt,
        prompt_sample_rate=prompt_sr,
        descriptions=descriptions,
        progress=True,
        return_tokens=False,
    )
    return output

def inference_musicgen_melody_condition(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    model.set_generation_params(**configs)
    descriptions = [text for _ in range(num_outputs)]
    # One melody per description: audiocraft expects matching batch sizes.
    melody_wavs = [prompt_waveform for _ in range(num_outputs)]
    output = model.generate_with_chroma(
        descriptions=descriptions,
        melody_wavs=melody_wavs,
        melody_sample_rate=prompt_sr,
        progress=True,
        return_tokens=False,
    )
    return output

def inference_magnet(model, configs, text, num_outputs=1):
    model.set_generation_params(**configs)
    descriptions = [text for _ in range(num_outputs)]
    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
    return output


def inference_magnet_audio(model, configs, text, num_outputs=1):
    model.set_generation_params(**configs)
    descriptions = [text for _ in range(num_outputs)]
    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
    return output


def inference_audiogen(model, configs, text, num_outputs=1):
    model.set_generation_params(**configs)
    descriptions = [text for _ in range(num_outputs)]
    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
    return output

def inference_musiclang():
    # TODO: Implement MusicLang
    pass


def process_audio(gr_audio, prompt_duration, model):
    # Load the uploaded file and keep only the first `prompt_duration` seconds.
    # (`model` is unused here; the commented variant below needed model.device.)
    # audio, sr = torch.from_numpy(gr_audio[1]).to(model.device).float().t(), gr_audio[0]
    audio, sr = torchaudio.load(gr_audio)
    audio = audio[..., :int(prompt_duration * sr)]
    return audio, sr

_MODEL_INFERENCES = {
    "facebook/musicgen-small": inference_musicgen_text_to_music,
    "facebook/musicgen-medium": inference_musicgen_text_to_music,
    "facebook/musicgen-large": inference_musicgen_text_to_music,
    "facebook/musicgen-melody": inference_musicgen_melody_condition,
    "facebook/musicgen-melody-large": inference_musicgen_melody_condition,
    "facebook/magnet-small-10secs": inference_magnet,
    "facebook/magnet-medium-10secs": inference_magnet,
    "facebook/magnet-small-30secs": inference_magnet,
    "facebook/magnet-medium-30secs": inference_magnet,
    "facebook/audio-magnet-small": inference_magnet_audio,
    "facebook/audio-magnet-medium": inference_magnet_audio,
    "facebook/audiogen-medium": inference_audiogen,
    "musicgen-continuation": inference_musicgen_continuation,
}
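
# Illustrative only: a typical `configs` dict for the MusicGen entries above,
# using audiocraft's MusicGen.set_generation_params keyword names. The values
# are assumptions for demonstration, not defaults taken from this app.
_EXAMPLE_MUSICGEN_CONFIGS = {
    "duration": 10,        # seconds of audio to generate
    "use_sampling": True,  # sample instead of greedy decoding
    "top_k": 250,          # top-k sampling
    "top_p": 0.0,          # 0.0 disables nucleus sampling
    "temperature": 1.0,
    "cfg_coef": 3.0,       # classifier-free guidance strength
}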

def _do_predictions(
    model_file,
    model,
    text,
    melody=None,
    mel_sample_rate=None,
    progress=False,
    num_generations=1,
    **gen_kwargs,
):
    print(
        "new generation",
        text,
        None if melody is None else melody.shape,
    )
    be = time.time()
    try:
        if melody is not None:
            # Melody conditioning or continuation.
            if 'melody' in model_file:
                # Melody conditioning: musicgen-melody, musicgen-melody-large.
                inference_func = _MODEL_INFERENCES[model_file]
            else:
                # Melody continuation.
                inference_func = _MODEL_INFERENCES['musicgen-continuation']
            outputs = inference_func(model, gen_kwargs, text, melody, mel_sample_rate, num_generations)
        else:
            # Text-to-music or text-to-sound.
            inference_func = _MODEL_INFERENCES[model_file]
            outputs = inference_func(model, gen_kwargs, text, num_generations)
    except RuntimeError as e:
        raise gr.Error("Error while generating: " + e.args[0])
    outputs = outputs.detach().cpu().float()
    out_audios = []
    video_processes = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name,
                output,
                model.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            # video_processes.append(pool.submit(make_waveform, file.name))
            out_audios.append(file.name)
            file_cleaner.add(file.name)
    # out_videos = [video.result() for video in video_processes]
    # for video in out_videos:
    #     file_cleaner.add(video)
    print("generation finished", len(outputs), time.time() - be)
    return out_audios

def make_waveform(*args, **kwargs):
    # Suppress the warnings emitted while rendering the waveform video.
    # Note: gr.make_waveform is deprecated in recent Gradio releases and
    # removed in Gradio 5; this helper is only used by commented-out code.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
    print("Make a video took", time.time() - be)
    return out

def predict(
    model_version,
    generation_configs,
    prompt_text=None,
    prompt_wav=None,
    num_generations=1,
    progress=gr.Progress(),
):
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model = load_model(model_version)
    model.set_custom_progress_callback(_progress)
    # The config may arrive from the UI as the string repr of a dict.
    if isinstance(generation_configs, str):
        generation_configs = ast.literal_eval(generation_configs)
    if prompt_wav is not None:
        melody, mel_sample_rate = process_audio(prompt_wav, generation_configs['duration'], model)
    else:
        melody, mel_sample_rate = None, None
    audios = _do_predictions(
        model_version,
        model,
        prompt_text,
        melody,
        mel_sample_rate,
        progress=True,
        num_generations=num_generations,
        **generation_configs,
    )
    return audios
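
# _progress above raises once INTERRUPTING flips. A minimal sketch of the
# matching UI callback (an assumption about how the Space wires its Stop
# button; the wiring itself is not shown in this section):
def interrupt():
    global INTERRUPTING
    INTERRUPTING = True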

def transcribe(audio_path):
    """
    Transcribe generated audio files to MIDI using the basic_pitch model.
    """
    # Gradio delivers the list of generated paths as its string repr.
    tmp_paths = ast.literal_eval(audio_path)
    download_buttons = []
    for audio_path in tmp_paths:
        model_output, midi_data, note_events = basic_pitch.inference.predict(
            audio_path=audio_path,
            model_or_model_path=ICASSP_2022_MODEL_PATH,
        )
        with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
            try:
                # Write via the temp file's path; passing the open handle can
                # fail with some pretty_midi/mido versions.
                midi_data.write(file.name)
                print(f"midi file saved to {file.name}")
            except Exception as e:
                print(f"Error while writing midi file: {e}")
                raise e
            download_buttons.append(gr.DownloadButton(
                value=file.name, label=f"Download MIDI file {file.name}", visible=True
            ))
            file_cleaner.add(file.name)
    return download_buttons
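
# Minimal local-run sketch (an assumption, not part of the original Space,
# which presumably builds its full UI elsewhere): wire only the text-to-music
# path for a quick smoke test.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=lambda text: predict(
            "facebook/musicgen-small",
            {"duration": 5, "top_k": 250, "top_p": 0.0, "temperature": 1.0},
            prompt_text=text,
        )[0],
        inputs=gr.Textbox(label="Prompt"),
        outputs=gr.Audio(type="filepath"),
    )
    demo.launch()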