Spaces:
Sleeping
Sleeping
include audio gen
Browse files
gradio_components/prediction.py
CHANGED
|
@@ -8,13 +8,16 @@ import gradio as gr
|
|
| 8 |
import torch
|
| 9 |
from audiocraft.data.audio import audio_write
|
| 10 |
from audiocraft.data.audio_utils import convert_audio
|
| 11 |
-
from audiocraft.models import MusicGen
|
| 12 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
| 13 |
from transformers import AutoModelForSeq2SeqLM
|
| 14 |
|
| 15 |
|
| 16 |
def load_model(version="facebook/musicgen-melody"):
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def _do_predictions(
|
|
|
|
| 8 |
import torch
|
| 9 |
from audiocraft.data.audio import audio_write
|
| 10 |
from audiocraft.data.audio_utils import convert_audio
|
| 11 |
+
from audiocraft.models import MusicGen, AudioGen
|
| 12 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
| 13 |
from transformers import AutoModelForSeq2SeqLM
|
| 14 |
|
| 15 |
|
| 16 |
def load_model(version="facebook/musicgen-melody"):
|
| 17 |
+
if version in ["facebook/audiogen-medium"]:
|
| 18 |
+
return AudioGen.get_pretrained(version)
|
| 19 |
+
else:
|
| 20 |
+
return MusicGen.get_pretrained(version)
|
| 21 |
|
| 22 |
|
| 23 |
def _do_predictions(
|