Spaces:
Sleeping
Sleeping
| import time | |
| import torch | |
| from audiocraft.data.audio_utils import convert_audio | |
| from audiocraft.data.audio import audio_write | |
| import gradio as gr | |
| from audiocraft.models import MusicGen | |
| from tempfile import NamedTemporaryFile | |
| from pathlib import Path | |
def load_model(version='facebook/musicgen-melody'):
    """Download (or reuse the cached copy of) the requested MusicGen checkpoint.

    Args:
        version: model identifier or local path understood by
            ``MusicGen.get_pretrained`` (defaults to the melody-conditioned model).

    Returns:
        The loaded ``MusicGen`` instance.
    """
    model = MusicGen.get_pretrained(version)
    return model
def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac=1, **gen_kwargs):
    """Run one batch of MusicGen generations and write the results to wav files.

    Args:
        model: a loaded ``MusicGen`` instance.
        texts: list of text prompts, one per batch item.
        melodies: list parallel to ``texts``; each entry is either ``None`` or a
            ``(sample_rate, numpy_array)`` pair as produced by a gradio Audio input.
        duration: target generation length in seconds (also used to trim melodies).
        progress: forwarded to the model's generate calls for console progress.
        gradio_progress: optional ``gr.Progress`` handle (currently unused here;
            progress is reported through the model's custom callback).
        target_sr: sample rate the melody conditioning is resampled to.
        target_ac: channel count the melody conditioning is downmixed to.
        **gen_kwargs: sampling parameters (top_k, top_p, temperature, ...)
            applied via ``model.set_generation_params``.

    Returns:
        List of paths to the generated wav files (temp files, not auto-deleted,
        so gradio can serve them).

    Raises:
        gr.Error: wrapping any RuntimeError raised during generation.
    """
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
            if melody.dim() == 1:
                # gradio hands mono audio back as 1-D; add a channel axis.
                melody = melody[None]
            # Trim the conditioning melody to the requested duration, then
            # resample/downmix to what the chroma conditioner expects.
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)
    # BUGFIX: duration and the sampling kwargs were previously accepted but
    # never applied — they must be set on the model before generating.
    model.set_generation_params(duration=duration, **gen_kwargs)
    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False
            )
        else:
            # text only
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty (e.args[0] would raise).
        raise gr.Error("Error while generating " + str(e)) from e
    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: gradio needs the file to survive past this function.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs
def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
    """Gradio endpoint: validate inputs, load the model, and generate audio.

    Args:
        model_path: optional local checkpoint folder; empty string uses the default.
        text: text prompt.
        melody: optional ``(sample_rate, numpy_array)`` melody conditioning.
        duration: generation length in seconds.
        topk: top-k sampling cutoff (non-negative int).
        topp: top-p (nucleus) sampling cutoff (non-negative).
        temperature: sampling temperature (>= 0).
        target_sr: sample rate for the melody conditioning.
        progress: gradio progress tracker (injected by gradio).

    Returns:
        List of generated wav file paths.

    Raises:
        gr.Error: on invalid inputs or when the user interrupts generation.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = model_path.strip()
    if model_path:
        if not Path(model_path).exists():
            raise gr.Error(f"Model path {model_path} doesn't exist.")
        if not Path(model_path).is_dir():
            raise gr.Error(f"Model path {model_path} must be a folder containing "
                           "state_dict.bin and compression_state_dict_.bin.")
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")
    topk = int(topk)
    model = load_model(model_path)

    max_generated = 0

    def _progress(generated, to_generate):
        # Track the furthest point reached so the bar never moves backwards.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)
    # BUGFIX: `model` was previously omitted from this call, shifting every
    # positional argument one slot left ([text] bound to `model`, etc.).
    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs