import argparse
import os

import gradio as gr

from data_sampler import DataSampler
from pipeline_ace_step import ACEStepPipeline
from ui.simple_components import create_simple_ui
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
parser.add_argument("--bf16", action='store_true', default=True)
parser.add_argument("--torch_compile", type=bool, default=False)
args = parser.parse_args()
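# CUDA_VISIBLE_DEVICES is only read when CUDA is initialized, so it must be
# set before the pipeline (and hence torch's CUDA context) is created below.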
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
persistent_storage_path = "./data"
def main(args):
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()
    # Thin wrapper around the pipeline with fixed defaults, intended for
    # programmatic callers (it is not wired into the UI created below).
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """
        Generate music from tags and lyrics.

        Args:
            duration: Duration in seconds (default 20).
            tags: Music tags / style description.
            lyrics: Lyrics, or "[instrumental]" for no vocals.
            infer_steps: Number of inference steps (default 60).
            guidance_scale: Guidance scale (default 15.0).

        Returns:
            audio_path: Path to the generated audio file, or None on failure.
        """
        result = model_demo(
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none",
        )

        # Return the first audio output (now always 24 kHz WAV).
        if result and len(result) > 0:
            return result[0]
        return None
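
    # A minimal in-process usage sketch of the wrapper above; the argument
    # values are illustrative, everything else comes from the defaults
    # defined in generate_music_api:
    #
    #   audio_path = generate_music_api(
    #       duration=30.0,
    #       tags="lofi, chill, mellow, 90 bpm, instrumental",
    #   )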
    # Build the simplified UI around the pipeline.
    demo = create_simple_ui(
        text2music_process_func=model_demo.__call__
    )

    # Expose the API: api_open is a queue() parameter, and queue() can
    # overwrite an api_open attribute set beforehand, so pass it explicitly.
    demo.queue(default_concurrency_limit=8, api_open=True).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )
if __name__ == "__main__":
    main(args)
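
# Example launch (the filename "app.py" is an assumption; the source does not
# name this file):
#
#   python app.py --checkpoint_path ./checkpoints --port 7860 --share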