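"""Gradio app for the ACE-Step text-to-music pipeline.

Builds the main demo UI and additionally registers a hidden /generate
endpoint so the model can be driven programmatically (see the
gradio_client sketch at the bottom of this file).
"""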
import argparse
import os

import gradio as gr

from data_sampler import DataSampler
from pipeline_ace_step import ACEStepPipeline
from ui.components import create_main_demo_ui
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true", default=False)
# BooleanOptionalAction (Python 3.9+) also registers --no-bf16; the previous
# `action='store_true', default=True` made the flag impossible to disable.
parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
# `type=bool` parses any non-empty string (even "False") as True; use a flag instead.
parser.add_argument("--torch_compile", action="store_true", default=False)
args = parser.parse_args()

# Restrict CUDA to the selected GPU before the pipeline initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

persistent_storage_path = "./data"
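# Example launch (values are illustrative; assumes this file is saved as app.py):
#   python app.py --checkpoint_path ./checkpoints --port 7860 --share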

def main(args):
    # Load the ACE-Step pipeline once; artifacts persist under ./data.
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()
    # Wrapper exposed as a programmatic API endpoint (registered via gr.Interface below).
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """Generate music from a tag/lyric description.

        Args:
            duration: Duration in seconds (default 20).
            tags: Music tags/style description.
            lyrics: Lyrics, or "[instrumental]" for no vocals.
            infer_steps: Number of inference steps (default 60).
            guidance_scale: Classifier-free guidance scale (default 15.0).

        Returns:
            Path to the generated audio file, or None if generation failed.
        """
        result = model_demo(
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none",
        )
        # The pipeline returns a tuple; the first element is the audio file path.
        if result and len(result) > 0:
            return result[0]
        return None
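    # Example direct call (illustrative values; requires the pipeline above to be loaded):
    #   audio_path = generate_music_api(duration=30.0, tags="lofi, chill, mellow")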
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )

    # Flag the demo's API as open; /generate below is the programmatic entry point.
    demo.api_open = True
    demo.api_name = "/generate_music"
    # Register the API-only interface inside the existing Blocks app.
    with demo:
        # Wrapped in a hidden Column: gr.Interface itself takes no `visible`
        # argument, so the column keeps the duplicate controls out of the UI
        # while the /generate route stays callable.
        with gr.Column(visible=False):
            gr.Interface(
                fn=generate_music_api,
                inputs=[
                    gr.Number(value=20, label="Duration (seconds)"),
                    gr.Textbox(value="edm, synth, bass, 128 bpm, energetic", label="Tags"),
                    gr.Textbox(value="[instrumental]", label="Lyrics"),
                    gr.Number(value=60, label="Inference Steps"),
                    gr.Number(value=15.0, label="Guidance Scale"),
                ],
                outputs=gr.Audio(type="filepath", label="Generated Music"),
                api_name="generate",
            )
    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main(args)
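# Example: driving the hidden endpoint from another process with gradio_client.
# A sketch only; assumes the server is reachable at the default local address.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   audio_path = client.predict(
#       20.0,                          # duration (seconds)
#       "edm, synth, bass, 128 bpm",   # tags
#       "[instrumental]",              # lyrics
#       60,                            # inference steps
#       15.0,                          # guidance scale
#       api_name="/generate",
#   )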