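"""Gradio demo entry point for the ACEStepPipeline text-to-music model.

Example launch (file name and checkpoint path are illustrative):

    python app.py --checkpoint_path ./checkpoints --port 7860
"""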
import argparse
import os

import gradio as gr

from data_sampler import DataSampler
from pipeline_ace_step import ACEStepPipeline
from ui.simple_components import create_simple_ui


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true", default=False)
# BooleanOptionalAction (Python 3.9+) also generates --no-bf16; the previous
# `action="store_true", default=True` flag could never be switched off.
parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
# `type=bool` treats any non-empty string (including "False") as True, so the
# flag is now an ordinary store_true switch, off by default.
parser.add_argument("--torch_compile", action="store_true", default=False)

args = parser.parse_args()
# Pin the process to the selected GPU; this must run before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)


persistent_storage_path = "./data"
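# Assumption: the pipeline expects this directory to exist; create it up front.
os.makedirs(persistent_storage_path, exist_ok=True)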


def main(args):
    # Build the generation pipeline once; it is shared by the UI and the API helper below.
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()
    
    # Create API function for external calls
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """
        API function to generate music
        
        Args:
            duration: Duration in seconds (default 20)
            tags: Music tags/style description
            lyrics: Lyrics or [instrumental] for no vocals
            infer_steps: Inference steps (default 60)
            guidance_scale: Guidance scale (default 15.0)
            
        Returns:
            audio_path: Path to generated audio file
        """
        result = model_demo(
            # User-facing parameters, forwarded from the function arguments.
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            # Sampler configuration, fixed to the standard demo settings.
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            # Guidance schedule: constant scale applied over half the trajectory.
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            # ERG toggles for the individual conditioning streams.
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            # Audio-to-audio and LoRA features are disabled for the plain text2music API.
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none",
        )
        # Return the audio file path
        if result and len(result) > 0:
            return result[0]  # Return first audio output (now always 24kHz WAV)
        return None
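
    # Illustrative direct call (values are examples, not recommendations):
    #     audio_path = generate_music_api(
    #         duration=30.0,
    #         tags="lofi, chill, mellow, 80 bpm, instrumental",
    #         lyrics="[instrumental]",
    #     )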

    # Build the simplified UI around the pipeline's __call__.
    demo = create_simple_ui(
        text2music_process_func=model_demo.__call__
    )

    # api_open is a parameter of queue() in Gradio; passing it there (rather than
    # setting demo.api_open directly, which queue() would overwrite) keeps the
    # programmatic API routes open to external callers.
    demo.queue(default_concurrency_limit=8, api_open=True).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main(args)