import argparse
from ui.simple_components import create_simple_ui
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
import os
import gradio as gr

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
# Note: with action='store_true' and default=True, bf16 is effectively always enabled.
parser.add_argument("--bf16", action='store_true', default=True)
# Note: argparse's type=bool treats any non-empty string (even "False") as True.
parser.add_argument("--torch_compile", type=bool, default=False)
args = parser.parse_args()
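
# Example launch command (a sketch: assumes this file is saved as app.py; the
# checkpoint directory shown is a placeholder, not a path shipped with the repo):
#   python app.py --checkpoint_path ./checkpoints --port 7860 --share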

# Restrict the process to the selected GPU before the model is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
persistent_storage_path = "./data"

def main(args):
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()

    # Create API function for external calls
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """
        API function to generate music.

        Args:
            duration: Duration in seconds (default 20)
            tags: Music tags/style description
            lyrics: Lyrics, or [instrumental] for no vocals
            infer_steps: Inference steps (default 60)
            guidance_scale: Guidance scale (default 15.0)

        Returns:
            audio_path: Path to the generated audio file
        """
        result = model_demo(
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none",
        )
        # Return the audio file path
        if result and len(result) > 0:
            return result[0]  # Return first audio output (now always 24kHz WAV)
        return None

    # Use the simplified UI
    demo = create_simple_ui(
        text2music_process_func=model_demo.__call__
    )

    # Add API endpoint to the demo (note: generate_music_api above is defined for
    # external callers but is never registered with the Blocks app here; only the
    # events wired up inside create_simple_ui are exposed).
    demo.api_open = True
    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main(args)
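
# --- Optional: calling the running app from another process ---
# A minimal client-side sketch using gradio_client. The api_name and argument
# order below are assumptions: the real endpoint names depend on how
# create_simple_ui wires its events, so check the app's "Use via API" page
# before relying on this.
#
# from gradio_client import Client
#
# client = Client("http://localhost:7860")
# audio_path = client.predict(
#     "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",  # tags
#     "[instrumental]",  # lyrics
#     api_name="/text2music",  # hypothetical endpoint name
# )
# print(audio_path)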