Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -207,9 +207,12 @@ def generate_music(prompt: str, audio_length: int):

Old version (lines 207–215):

    207      model_key = "facebook/musicgen-large"
    208      musicgen_model, musicgen_processor = get_musicgen_model(model_key)
    209      device = "cuda" if torch.cuda.is_available() else "cpu"
    210  -   [removed line — content lost in page extraction]
    211      with torch.inference_mode():
    212          outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
    213      audio_data = outputs[0, 0].cpu().numpy()
    214      normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
    215      output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
|
|
|
|
New version (lines 207–218):

    207      model_key = "facebook/musicgen-large"
    208      musicgen_model, musicgen_processor = get_musicgen_model(model_key)
    209      device = "cuda" if torch.cuda.is_available() else "cpu"
    210  +   # Process the input and move each tensor to the proper device
    211  +   inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
    212  +   inputs = {k: v.to(device) for k, v in inputs.items()}
    213      with torch.inference_mode():
    214          outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
    215  +   # Post-process the output to create a WAV file
    216      audio_data = outputs[0, 0].cpu().numpy()
    217      normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
    218      output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")