| import argparse | |
| import logging | |
| import os | |
| from typing import Optional | |
| from core.bark.generate_audio_semantic_dataset import ( | |
| generate_wav_semantic_dataset, | |
| BarkGenerationConfig, | |
| ) | |
| from core.utils import upload_file_to_hf, zip_folder | |
# Set up process-wide logging at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
def parse_dataset_args(args_list=None):
    """Parse command-line arguments for audio semantic dataset creation.

    Args:
        args_list: Optional list of argument strings to parse. It is passed
            straight to ``ArgumentParser.parse_args``; when ``None``,
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``text_file``,
        ``batch_size``, ``output_dir``, ``max_tokens``, ``use_small_model``,
        ``save_raw_audio``, ``publish_hf``, ``repo_id``, ``path_in_repo``
        and ``silent``.
    """
    parser = argparse.ArgumentParser(description="Audio Semantic Dataset Creation")

    # --- Input / output locations and sizing ---
    parser.add_argument(
        "--text-file",
        type=str,
        default="data/test_data.txt",
        help="Path to text file for dataset generation",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        # The help text must state the value actually used as the default.
        help="Batch size for processing (default: 2)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./dataset",
        help="Output directory for generated files (default: ./dataset)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=256,
        help="Maximum tokens per example (default: 256)",
    )

    # --- Generation switches (all off unless the flag is given) ---
    parser.add_argument(
        "--use-small-model",
        action="store_true",
        help="Use small model for generation",
    )
    parser.add_argument(
        "--save-raw-audio",
        action="store_true",
        help="Store generated audio as .wav instead of .npz",
    )

    # --- HuggingFace Hub publishing ---
    parser.add_argument(
        "--publish-hf",
        action="store_true",
        help="Publish dataset to HuggingFace Hub",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="HuggingFace repo ID to publish to",
    )
    parser.add_argument(
        "--path-in-repo",
        type=str,
        help="Path in HF repo",
        default=None,
    )

    parser.add_argument(
        "--silent", action="store_true", help="Suppress progress output"
    )

    return parser.parse_args(args_list)
def create_audio_semantic_dataset(
    text_file: str,
    output_dir: str = "./dataset",
    batch_size: int = 1,
    max_tokens: int = 256,
    use_small_model: bool = False,
    save_raw_audio: bool = False,
    publish_hf: bool = False,
    repo_id: Optional[str] = None,
    path_in_repo: Optional[str] = None,
    silent: bool = False,
) -> None:
    """Create an audio semantic dataset from a text file.

    Can be called directly with parameters or via the command line using
    parse_dataset_args().

    Args:
        text_file: Path to input text file.
        output_dir: Directory to save generated dataset; created if missing.
        batch_size: Batch size for processing.
        max_tokens: Maximum tokens per example.
            NOTE(review): accepted but not forwarded to the generator by
            this function - confirm whether it should be.
        use_small_model: Whether to use small model.
        save_raw_audio: Save as raw audio (.wav) instead of .npz.
        publish_hf: Whether to publish to HuggingFace Hub. Publishing only
            happens when ``repo_id`` is also provided; otherwise a warning
            is logged and the dataset is generated locally only.
        repo_id: HF repo ID to publish to.
        path_in_repo: Path in HF repo.
        silent: Suppress progress output.

    Raises:
        FileNotFoundError: If ``text_file`` is not an existing file.
        RuntimeError: If the output directory cannot be zipped for upload.
    """
    # Validate the input before touching the filesystem so that a bad path
    # does not leave an empty output directory behind.
    if not os.path.isfile(text_file):
        raise FileNotFoundError(f"Text file not found: {text_file}")

    # Publishing needs a target repo. Say so up front instead of silently
    # skipping the upload after a potentially long generation run.
    if publish_hf and not repo_id:
        logger.warning(
            "publish_hf is set but no repo_id was provided; "
            "the dataset will not be published"
        )

    os.makedirs(output_dir, exist_ok=True)

    logger.info("Starting dataset generation from %s", text_file)

    # All temperatures are passed as None.
    # NOTE(review): presumably this selects the generator's default sampling
    # behaviour - confirm against BarkGenerationConfig.
    generation_config = BarkGenerationConfig(
        temperature=None,
        generate_coarse_temperature=None,
        generate_fine_temperature=None,
        use_small_model=use_small_model,
    )

    generate_wav_semantic_dataset(
        text_file_path=text_file,
        generation_config=generation_config,
        batch_size=batch_size,
        save_path=output_dir,
        save_data_as_raw_audio=save_raw_audio,
        silent=silent,
    )
    logger.info("Dataset generation completed")

    if publish_hf and repo_id:
        logger.info("Publishing dataset to huggingface hub")
        # The archive is written relative to the current working directory,
        # not inside output_dir.
        zip_path = "./dataset.zip"
        success = zip_folder(output_dir, zip_path)
        if not success:
            raise RuntimeError(f"Unable to zip folder {output_dir}")
        # NOTE(review): the third positional argument "dataset" is presumably
        # the HF repo type - confirm against upload_file_to_hf.
        upload_file_to_hf(zip_path, repo_id, "dataset", path_in_repo=path_in_repo)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and forward each one to
    # create_audio_semantic_dataset() under the keyword of the same name.
    cli_options = parse_dataset_args()
    forwarded_fields = (
        "text_file",
        "output_dir",
        "batch_size",
        "max_tokens",
        "use_small_model",
        "save_raw_audio",
        "publish_hf",
        "repo_id",
        "path_in_repo",
        "silent",
    )
    call_kwargs = {
        field: getattr(cli_options, field) for field in forwarded_fields
    }
    create_audio_semantic_dataset(**call_kwargs)