# ACE-Step text-to-music Gradio app (Hugging Face Space, running on ZeroGPU).
import argparse
import os

import gradio as gr

from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
from ui.simple_components import create_simple_ui
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
parser.add_argument("--bf16", action='store_true', default=True)
parser.add_argument("--torch_compile", type=bool, default=False)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
persistent_storage_path = "./data"
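# Example launch command (illustrative; assumes this file is saved as app.py,
# and the checkpoint path shown is a placeholder):
#   python app.py --checkpoint_path ./checkpoints --port 7860 --share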
def main(args):
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile
    )
    data_sampler = DataSampler()

    # Create API function for external calls
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """
        API function to generate music.

        Args:
            duration: Duration in seconds (default 20)
            tags: Music tags/style description
            lyrics: Lyrics, or [instrumental] for no vocals
            infer_steps: Inference steps (default 60)
            guidance_scale: Guidance scale (default 15.0)

        Returns:
            audio_path: Path to generated audio file
        """
        result = model_demo(
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none"
        )

        # Return the audio file path
        if result and len(result) > 0:
            return result[0]  # Return first audio output (now always 24kHz WAV)
        return None
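    # Illustrative call (not executed at startup; the values are just examples):
    #   audio_path = generate_music_api(
    #       duration=30.0,
    #       tags="lo-fi, mellow piano, 80 bpm, instrumental",
    #       lyrics="[instrumental]",
    #   )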
    # Use simplified UI
    demo = create_simple_ui(
        text2music_process_func=model_demo.__call__
    )

    # Add API endpoint to the demo
    demo.api_open = True

    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share
    )
if __name__ == "__main__":
    main(args)
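# Once the app is running, one way to call it programmatically (a sketch, not part
# of this app) is the gradio_client package; view_api() lists whichever endpoints
# create_simple_ui registers, since their names are defined there rather than here:
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   client.view_api()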