# MusicMaker: text-to-music generation utilities (MusicGen wrapper).
import asyncio
import json
import shutil
import uuid
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Literal

import toml
import torch
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from pydub import AudioSegment

from modules.llms import get_llm_factory

from .utils import set_all_seeds
class MusicMaker:
    """Generate background music from text prompts via Meta's MusicGen.

    Loads a MusicGen checkpoint once at construction time and writes the
    generated audio into the local ``outputs/`` directory.
    """

    def __init__(self, model_size: Literal['small', 'medium', 'melody', 'large'] = 'large',
                 output_format: Literal['wav', 'mp3'] = 'mp3',
                 device: str = None) -> None:
        """Initialize the MusicMaker class.

        Args:
            model_size (Literal['small', 'medium', 'melody', 'large'], optional): Model size. Defaults to 'large'.
            output_format (Literal['wav', 'mp3'], optional): Output format. Defaults to 'mp3'.
            device (str, optional): Device to use for the model. Defaults to None
                (auto-select: CUDA when available, otherwise CPU).

        Raises:
            FileExistsError: If a regular file named ``outputs`` blocks creation
                of the output directory.
        """
        self.__model_size = model_size
        self.__output_format = output_format
        # An explicit `device` argument wins; otherwise auto-detect CUDA.
        self.__device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if not device else device

        print("Loading the MusicGen model into memory...")
        self.__mg_model = MusicGen.get_pretrained(self.model_size, device=self.device)
        self.__mg_model.set_generation_params(use_sampling=True,
                                              top_k=250,
                                              top_p=0.0,
                                              temperature=1.0,
                                              cfg_coef=3.0)

        output_dir = Path('.') / 'outputs'
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=True)
        elif output_dir.is_file():
            # `assert` is stripped under `python -O`; raise explicitly instead.
            raise FileExistsError(
                f"A file with the same name as the desired directory ('{str(output_dir)}') already exists.")

    def text2music(self, prompt: str, length: int = 60, seed: int = None) -> str:
        """Generate a music from the prompt.

        Args:
            prompt (str): Prompt to generate the music from.
            length (int, optional): Length of the music in seconds. Defaults to 60.
            seed (int, optional): Seed to use for the generation. ``None``, 0 or
                -1 picks a random seed. Defaults to None.

        Returns:
            str: Path to the generated music.
        """
        def _wav_to_mp3(src_file: str, dest_file: str) -> None:
            # Transcode the intermediate WAV into the requested MP3 output.
            AudioSegment.from_wav(src_file).export(dest_file, format="mp3")

        output_filename = Path('.') / 'outputs' / str(uuid.uuid4())

        if not seed or seed == -1:
            # NOTE: seed == 0 is falsy, so it is also re-randomized here.
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        set_all_seeds(seed)

        self.__mg_model.set_generation_params(duration=length)
        output = self.__mg_model.generate(descriptions=[prompt], progress=True)[0]

        with NamedTemporaryFile("wb", delete=True) as temp_file:
            # audiocraft's audio_write appends the '.wav' suffix itself.
            audio_write(temp_file.name, output.cpu(), self.__mg_model.sample_rate,
                        strategy="loudness", loudness_compressor=True)
            if self.output_format == 'mp3':
                _wav_to_mp3(f'{temp_file.name}.wav', str(output_filename.with_suffix('.mp3')))
            else:
                shutil.copy(f'{temp_file.name}.wav', str(output_filename.with_suffix('.wav')))

        return str(output_filename.with_suffix('.mp3' if self.output_format == 'mp3' else '.wav'))

    def generate_prompt(self, genre: str, place: str, mood: str,
                        title: str, chapter_title: str, chapter_plot: str,
                        llm_type: str = "PaLM",
                        ) -> str:
        """Generate a prompt for a background music based on given attributes.

        Args:
            genre (str): Genre of the story.
            place (str): Place of the story.
            mood (str): Mood of the story.
            title (str): Title of the story.
            chapter_title (str): Title of the chapter.
            chapter_plot (str): Plot of the chapter.
            llm_type (str, optional): Which LLM backend to use. Defaults to "PaLM".

        Returns:
            str: Generated prompt.

        Raises:
            TimeoutError: If the LLM does not respond within 10 seconds.
            Exception: If the LLM call fails for any other reason.
            ValueError: If the LLM response is not valid JSON.
        """
        factory = get_llm_factory(llm_type)
        prompt_manager = factory.create_prompt_manager()
        llm_service = factory.create_llm_service()

        # Build the query by nesting the story attributes into the templates.
        gen_template = prompt_manager.prompts['music_gen']['gen_prompt']
        query_template = prompt_manager.prompts['music_gen']['query']
        query_string = gen_template.format(input=query_template.format(genre=genre,
                                                                       place=place,
                                                                       mood=mood,
                                                                       title=title,
                                                                       chapter_title=chapter_title,
                                                                       chapter_plot=chapter_plot))
        try:
            response, response_txt = asyncio.run(asyncio.wait_for(
                llm_service.gen_text(query_string, mode="text", use_filter=False),
                timeout=10)
            )
        except asyncio.TimeoutError as err:
            raise TimeoutError("The response time for PaLM API exceeded the limit.") from err
        except Exception as err:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # are no longer swallowed; original cause is chained.
            raise Exception("PaLM API is not available.") from err

        try:
            res_json = json.loads(response_txt)
        except (json.JSONDecodeError, TypeError) as err:
            # Dump the raw response to aid debugging malformed LLM output.
            print("=== PaLM Response ===")
            print(response.filters)
            print(response_txt)
            print("=== PaLM Response ===")
            raise ValueError("The response from PaLM API is not in the expected format.") from err

        return res_json['primary_sentence']

    @property
    def model_size(self) -> Literal['small', 'medium', 'melody', 'large']:
        """Model size

        Returns:
            Literal['small', 'medium', 'melody', 'large']: The model size (read-only)
        """
        return self.__model_size

    @property
    def output_format(self) -> Literal['wav', 'mp3']:
        """Output format

        Returns:
            Literal['wav', 'mp3']: The output format (read-only)
        """
        return self.__output_format

    @property
    def device(self) -> str:
        """Device

        Returns:
            str: The device (read-only)
        """
        return self.__device