Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,12 +25,7 @@ from TTS.api import TTS
|
|
| 25 |
# Diffusers for sound design generation
|
| 26 |
from diffusers import DiffusionPipeline, AudioLDMPipeline
|
| 27 |
import diffusers
|
| 28 |
-
|
| 29 |
-
# Monkey-patch: Create a patched pipeline class so that any reference to AudioLDM2Pipeline is resolved correctly.
|
| 30 |
-
class PatchedAudioLDM2Pipeline(AudioLDMPipeline):
|
| 31 |
-
pass
|
| 32 |
-
|
| 33 |
-
setattr(diffusers, "AudioLDM2Pipeline", PatchedAudioLDM2Pipeline)
|
| 34 |
|
| 35 |
# ---------------------------------------------------------------------
|
| 36 |
# Setup Logging and Environment Variables
|
|
@@ -107,11 +102,16 @@ def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
|
|
| 107 |
def get_sound_design_pipeline(model_name: str, token: str):
|
| 108 |
"""
|
| 109 |
Returns a cached DiffusionPipeline for sound design if available;
|
| 110 |
-
otherwise, it loads and caches the pipeline
|
|
|
|
|
|
|
|
|
|
| 111 |
"""
|
|
|
|
|
|
|
| 112 |
if model_name in SOUND_DESIGN_PIPELINES:
|
| 113 |
return SOUND_DESIGN_PIPELINES[model_name]
|
| 114 |
-
pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=
|
| 115 |
SOUND_DESIGN_PIPELINES[model_name] = pipe
|
| 116 |
return pipe
|
| 117 |
|
|
|
|
| 25 |
# Diffusers for sound design generation
|
| 26 |
from diffusers import DiffusionPipeline, AudioLDMPipeline
|
| 27 |
import diffusers
|
| 28 |
+
from packaging import version
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# ---------------------------------------------------------------------
|
| 31 |
# Setup Logging and Environment Variables
|
|
|
|
| 102 |
def get_sound_design_pipeline(model_name: str, token: str):
|
| 103 |
"""
|
| 104 |
Returns a cached DiffusionPipeline for sound design if available;
|
| 105 |
+
otherwise, it loads and caches the pipeline.
|
| 106 |
+
|
| 107 |
+
NOTE: AudioLDM2Pipeline is available only in diffusers>=0.21.0.
|
| 108 |
+
Since your requirements fix diffusers==0.20.2, this function will raise an error.
|
| 109 |
"""
|
| 110 |
+
if version.parse(diffusers.__version__) < version.parse("0.21.0"):
|
| 111 |
+
raise ValueError("AudioLDM2 requires diffusers>=0.21.0. Please upgrade your diffusers package.")
|
| 112 |
if model_name in SOUND_DESIGN_PIPELINES:
|
| 113 |
return SOUND_DESIGN_PIPELINES[model_name]
|
| 114 |
+
pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=AudioLDMPipeline, use_auth_token=token)
|
| 115 |
SOUND_DESIGN_PIPELINES[model_name] = pipe
|
| 116 |
return pipe
|
| 117 |
|