Commit
·
8ba62a7
1
Parent(s):
5af7cde
braindead progress updates for /generate to be used in JUCE
Browse files- app.py +155 -62
- one_shot_generation.py +9 -2
app.py
CHANGED
|
@@ -205,6 +205,8 @@ _patch_t5x_for_gpu_coords()
|
|
| 205 |
jam_registry: dict[str, JamWorker] = {}
|
| 206 |
jam_lock = threading.Lock()
|
| 207 |
|
|
|
|
|
|
|
| 208 |
@contextmanager
|
| 209 |
def mrt_overrides(mrt, **kwargs):
|
| 210 |
"""Temporarily set attributes on MRT if they exist; restore after."""
|
|
@@ -331,6 +333,33 @@ app.add_middleware(
|
|
| 331 |
_MRT = None
|
| 332 |
_MRT_LOCK = threading.Lock()
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
def get_mrt():
|
| 335 |
global _MRT
|
| 336 |
if _MRT is None:
|
|
@@ -441,6 +470,8 @@ def _boot():
|
|
| 441 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 442 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 443 |
|
|
|
|
|
|
|
| 444 |
@app.get("/model/status")
|
| 445 |
def model_status():
|
| 446 |
mrt = get_mrt()
|
|
@@ -674,7 +705,9 @@ def model_select(req: ModelSelect):
|
|
| 674 |
# one-shot generation
|
| 675 |
# ----------------------------
|
| 676 |
|
| 677 |
-
|
|
|
|
|
|
|
| 678 |
|
| 679 |
@app.post("/generate")
|
| 680 |
def generate(
|
|
@@ -691,76 +724,136 @@ def generate(
|
|
| 691 |
temperature: float = Form(1.1),
|
| 692 |
topk: int = Form(40),
|
| 693 |
target_sample_rate: int | None = Form(None),
|
| 694 |
-
intro_bars_to_drop: int = Form(0),
|
|
|
|
| 695 |
):
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
if not data:
|
| 699 |
-
return {"error": "Empty file"}
|
| 700 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
| 701 |
-
tmp.write(data)
|
| 702 |
-
tmp_path = tmp.name
|
| 703 |
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
mrt,
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
)
|
| 727 |
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
| 736 |
-
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
| 737 |
-
expected_secs = float(bars) * seconds_per_bar
|
| 738 |
-
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
|
| 739 |
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
|
| 744 |
-
# 4) Metadata
|
| 745 |
-
metadata = {
|
| 746 |
-
"bpm": int(round(bpm)),
|
| 747 |
-
"bars": int(bars),
|
| 748 |
-
"beats_per_bar": int(beats_per_bar),
|
| 749 |
-
"styles": extra_styles,
|
| 750 |
-
"style_weights": weights,
|
| 751 |
-
"loop_weight": loop_weight,
|
| 752 |
-
"loudness": loud_stats,
|
| 753 |
-
"sample_rate": int(target_sr),
|
| 754 |
-
"channels": int(channels),
|
| 755 |
-
"crossfade_seconds": mrt.config.crossfade_length,
|
| 756 |
-
"total_samples": int(total_samples),
|
| 757 |
-
"seconds_per_bar": seconds_per_bar,
|
| 758 |
-
"loop_duration_seconds": loop_duration_seconds,
|
| 759 |
-
"guidance_weight": guidance_weight,
|
| 760 |
-
"temperature": temperature,
|
| 761 |
-
"topk": topk,
|
| 762 |
-
}
|
| 763 |
-
return {"audio_base64": audio_b64, "metadata": metadata}
|
| 764 |
|
| 765 |
# new endpoint to return a bar-aligned chunk without the need for combined audio
|
| 766 |
|
|
|
|
| 205 |
jam_registry: dict[str, JamWorker] = {}
|
| 206 |
jam_lock = threading.Lock()
|
| 207 |
|
| 208 |
+
|
| 209 |
+
|
| 210 |
@contextmanager
|
| 211 |
def mrt_overrides(mrt, **kwargs):
|
| 212 |
"""Temporarily set attributes on MRT if they exist; restore after."""
|
|
|
|
| 333 |
_MRT = None
|
| 334 |
_MRT_LOCK = threading.Lock()
|
| 335 |
|
| 336 |
+
# Latest progress snapshot per /generate request, keyed by request id.
# The /generate handler also writes error snapshots into this dict directly
# while holding _PROGRESS_LOCK, so it stays a plain dict guarded by a plain lock.
_PROGRESS = {}
_PROGRESS_LOCK = threading.Lock()

# Snapshots whose "ts" is older than this many seconds are discarded on the
# next update; without this the registry gains one entry per request forever.
_PROGRESS_TTL_SECONDS = 3600.0


def _progress_prune_locked(now: float) -> None:
    """Drop snapshots older than the TTL. The caller must hold _PROGRESS_LOCK."""
    cutoff = now - _PROGRESS_TTL_SECONDS
    # Collect keys first: a dict must not change size while being iterated.
    # Entries without a "ts" are treated as fresh and kept.
    stale = [key for key, snap in _PROGRESS.items() if snap.get("ts", now) < cutoff]
    for key in stale:
        del _PROGRESS[key]


def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
    """Record the current progress snapshot for a request.

    Args:
        req_id: Request identifier; a falsy value makes this a no-op.
        n: Number of steps completed so far.
        total: Total number of steps expected.
        stage: Label for the phase currently running.

    The stored "percent" is ``n`` clamped to ``[0, total]`` over ``total``
    (a non-positive ``total`` is treated as 1 to avoid dividing by zero),
    rounded to an integer.
    """
    if not req_id:
        return
    completed = max(0, min(n, total))
    percent = int(round(100.0 * completed / max(1, total)))
    now = time.time()
    with _PROGRESS_LOCK:
        # Expire old snapshots here so the registry stays bounded without
        # needing a background thread.
        _progress_prune_locked(now)
        _PROGRESS[req_id] = {
            "n": int(n),
            "total": int(total),
            "percent": percent,
            "stage": stage,
            "ts": now,
        }


def _progress_done(req_id: str):
    """Mark a request as finished (100%, stage "done").

    The previously announced total is kept so ``n == total`` in the final
    snapshot; if no total was ever recorded, the last ``n`` or 1 is used.
    A falsy ``req_id`` makes this a no-op.
    """
    if not req_id:
        return
    with _PROGRESS_LOCK:
        st = _PROGRESS.get(req_id, {})
        total = st.get("total") or st.get("n") or 1
        _PROGRESS[req_id] = {
            "n": total,
            "total": total,
            "percent": 100,
            "stage": "done",
            "ts": time.time(),
        }


def _progress_get(req_id: str):
    """Return a copy of the latest snapshot for ``req_id``.

    Unknown (or expired) ids yield ``{"percent": 0, "stage": "pending"}``.
    A copy is returned so callers cannot mutate the shared registry entry
    outside the lock.
    """
    with _PROGRESS_LOCK:
        snapshot = _PROGRESS.get(req_id)
        if snapshot is None:
            return {"percent": 0, "stage": "pending"}
        return dict(snapshot)
|
| 362 |
+
|
| 363 |
def get_mrt():
|
| 364 |
global _MRT
|
| 365 |
if _MRT is None:
|
|
|
|
| 470 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 471 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 472 |
|
| 473 |
+
|
| 474 |
+
|
| 475 |
@app.get("/model/status")
|
| 476 |
def model_status():
|
| 477 |
mrt = get_mrt()
|
|
|
|
| 705 |
# one-shot generation
|
| 706 |
# ----------------------------
|
| 707 |
|
| 708 |
+
@app.get("/progress")
def progress(request_id: str):
    """Report the most recent progress snapshot stored for ``request_id``."""
    snapshot = _progress_get(request_id)
    return snapshot
|
| 711 |
|
| 712 |
@app.post("/generate")
|
| 713 |
def generate(
|
|
|
|
| 724 |
temperature: float = Form(1.1),
|
| 725 |
topk: int = Form(40),
|
| 726 |
target_sample_rate: int | None = Form(None),
|
| 727 |
+
intro_bars_to_drop: int = Form(0),
|
| 728 |
+
request_id: str = Form(None),
|
| 729 |
):
|
| 730 |
+
req_id = request_id or str(uuid.uuid4())
|
| 731 |
+
tmp_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
+
try:
|
| 734 |
+
# 0) Read file -> tmp wav
|
| 735 |
+
data = loop_audio.file.read()
|
| 736 |
+
if not data:
|
| 737 |
+
# finalize progress as error and return
|
| 738 |
+
with _PROGRESS_LOCK:
|
| 739 |
+
_PROGRESS[req_id] = {
|
| 740 |
+
"percent": 100,
|
| 741 |
+
"stage": "error",
|
| 742 |
+
"error": "Empty file",
|
| 743 |
+
"ts": time.time(),
|
| 744 |
+
}
|
| 745 |
+
return {"error": "Empty file", "request_id": req_id}
|
| 746 |
|
| 747 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
| 748 |
+
tmp.write(data)
|
| 749 |
+
tmp_path = tmp.name
|
| 750 |
+
|
| 751 |
+
# 1) Parse styles + weights
|
| 752 |
+
extra_styles = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
|
| 753 |
+
weights = [float(x) for x in style_weights.split(",")] if style_weights else None
|
| 754 |
+
|
| 755 |
+
# 2) Get model and apply per-request overrides
|
| 756 |
+
mrt = get_mrt()
|
| 757 |
+
with mrt_overrides(
|
| 758 |
mrt,
|
| 759 |
+
guidance_weight=guidance_weight,
|
| 760 |
+
temperature=temperature,
|
| 761 |
+
topk=topk,
|
| 762 |
+
):
|
| 763 |
+
# progress callback (called from the generator loop)
|
| 764 |
+
def on_chunk(i, total):
|
| 765 |
+
_progress_update(req_id, i, total, stage="generating")
|
| 766 |
+
|
| 767 |
+
# 2a) (optional) emit initial 0% once steps are known:
|
| 768 |
+
# We'll do this inside the generator right after steps is computed.
|
| 769 |
+
wav, loud_stats = generate_loop_continuation_with_mrt(
|
| 770 |
+
mrt,
|
| 771 |
+
input_wav_path=tmp_path,
|
| 772 |
+
bpm=bpm,
|
| 773 |
+
extra_styles=extra_styles,
|
| 774 |
+
style_weights=weights,
|
| 775 |
+
bars=bars,
|
| 776 |
+
beats_per_bar=beats_per_bar,
|
| 777 |
+
loop_weight=loop_weight,
|
| 778 |
+
loudness_mode=loudness_mode,
|
| 779 |
+
loudness_headroom_db=loudness_headroom_db,
|
| 780 |
+
intro_bars_to_drop=intro_bars_to_drop,
|
| 781 |
+
progress_cb=on_chunk,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# 3) Post-process stages (optional: expose sub-stages for nicer UI)
|
| 785 |
+
# Mark "postprocess" before we resample/snap/encode.
|
| 786 |
+
st = _PROGRESS_GET(req_id) if False else None # (placeholder so lints don't complain)
|
| 787 |
+
_progress_update(
|
| 788 |
+
req_id,
|
| 789 |
+
_progress_get(req_id).get("total", 1),
|
| 790 |
+
_progress_get(req_id).get("total", 1),
|
| 791 |
+
"postprocess",
|
| 792 |
)
|
| 793 |
|
| 794 |
+
# 3a) Determine SR
|
| 795 |
+
inp_info = sf.info(tmp_path)
|
| 796 |
+
input_sr = int(inp_info.samplerate)
|
| 797 |
+
target_sr_val = int(target_sample_rate or input_sr)
|
| 798 |
+
|
| 799 |
+
# 3b) Convert SR + snap to exact bars
|
| 800 |
+
cur_sr = int(mrt.sample_rate)
|
| 801 |
+
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
| 802 |
+
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
| 803 |
+
expected_secs = float(bars) * seconds_per_bar
|
| 804 |
+
|
| 805 |
+
# (optional) sub-stage
|
| 806 |
+
_progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "resample_and_snap")
|
| 807 |
+
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr_val, seconds=expected_secs)
|
| 808 |
+
|
| 809 |
+
# 3c) Encode WAV -> base64
|
| 810 |
+
_progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "encode")
|
| 811 |
+
audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr_val)
|
| 812 |
+
loop_duration_seconds = total_samples / float(target_sr_val)
|
| 813 |
+
|
| 814 |
+
# 4) Metadata
|
| 815 |
+
metadata = {
|
| 816 |
+
"bpm": int(round(bpm)),
|
| 817 |
+
"bars": int(bars),
|
| 818 |
+
"beats_per_bar": int(beats_per_bar),
|
| 819 |
+
"styles": extra_styles,
|
| 820 |
+
"style_weights": weights,
|
| 821 |
+
"loop_weight": loop_weight,
|
| 822 |
+
"loudness": loud_stats,
|
| 823 |
+
"sample_rate": int(target_sr_val),
|
| 824 |
+
"channels": int(channels),
|
| 825 |
+
"crossfade_seconds": mrt.config.crossfade_length,
|
| 826 |
+
"total_samples": int(total_samples),
|
| 827 |
+
"seconds_per_bar": seconds_per_bar,
|
| 828 |
+
"loop_duration_seconds": loop_duration_seconds,
|
| 829 |
+
"guidance_weight": guidance_weight,
|
| 830 |
+
"temperature": temperature,
|
| 831 |
+
"topk": topk,
|
| 832 |
+
}
|
| 833 |
|
| 834 |
+
_progress_done(req_id)
|
| 835 |
+
return {"audio_base64": audio_b64, "metadata": metadata, "request_id": req_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
|
| 837 |
+
except Exception as e:
|
| 838 |
+
# Flip to error state so the UI stops polling and can show a message
|
| 839 |
+
with _PROGRESS_LOCK:
|
| 840 |
+
_PROGRESS[req_id] = {
|
| 841 |
+
"percent": 100,
|
| 842 |
+
"stage": "error",
|
| 843 |
+
"error": str(e),
|
| 844 |
+
"ts": time.time(),
|
| 845 |
+
}
|
| 846 |
+
# Re-raise so FastAPI returns a 500 (or your exception handler formats it)
|
| 847 |
+
raise
|
| 848 |
+
|
| 849 |
+
finally:
|
| 850 |
+
# Clean up temp file
|
| 851 |
+
if tmp_path:
|
| 852 |
+
try:
|
| 853 |
+
os.unlink(tmp_path)
|
| 854 |
+
except Exception:
|
| 855 |
+
pass
|
| 856 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
# new endpoint to return a bar-aligned chunk without the need for combined audio
|
| 859 |
|
one_shot_generation.py
CHANGED
|
@@ -29,6 +29,7 @@ def generate_loop_continuation_with_mrt(
|
|
| 29 |
loudness_mode: str = "auto",
|
| 30 |
loudness_headroom_db: float = 1.0,
|
| 31 |
intro_bars_to_drop: int = 0,
|
|
|
|
| 32 |
):
|
| 33 |
"""
|
| 34 |
Generate a continuation of an input loop using MagentaRT.
|
|
@@ -45,6 +46,7 @@ def generate_loop_continuation_with_mrt(
|
|
| 45 |
loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
|
| 46 |
loudness_headroom_db: Headroom in dB for peak limiting
|
| 47 |
intro_bars_to_drop: Number of intro bars to generate then drop
|
|
|
|
| 48 |
|
| 49 |
Returns:
|
| 50 |
Tuple of (au.Waveform output, dict loudness_stats)
|
|
@@ -90,13 +92,18 @@ def generate_loop_continuation_with_mrt(
|
|
| 90 |
|
| 91 |
# Chunk scheduling to cover gen_total_secs
|
| 92 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
| 93 |
-
steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# Generate
|
| 96 |
chunks = []
|
| 97 |
-
for
|
| 98 |
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
| 99 |
chunks.append(wav)
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Stitch continuous audio
|
| 102 |
stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
|
|
|
|
| 29 |
loudness_mode: str = "auto",
|
| 30 |
loudness_headroom_db: float = 1.0,
|
| 31 |
intro_bars_to_drop: int = 0,
|
| 32 |
+
progress_cb=None
|
| 33 |
):
|
| 34 |
"""
|
| 35 |
Generate a continuation of an input loop using MagentaRT.
|
|
|
|
| 46 |
loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
|
| 47 |
loudness_headroom_db: Headroom in dB for peak limiting
|
| 48 |
intro_bars_to_drop: Number of intro bars to generate then drop
|
| 49 |
+
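progress_cb: Optional callable invoked as progress_cb(completed_chunks, total_chunks): once with 0 before the first chunk, then again after each generated chunk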
|
| 50 |
|
| 51 |
Returns:
|
| 52 |
Tuple of (au.Waveform output, dict loudness_stats)
|
|
|
|
| 92 |
|
| 93 |
# Chunk scheduling to cover gen_total_secs
|
| 94 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
| 95 |
+
steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
|
| 96 |
+
|
| 97 |
+
if progress_cb:
|
| 98 |
+
progress_cb(0, steps) # announce total before first chunk
|
| 99 |
|
| 100 |
# Generate
|
| 101 |
chunks = []
|
| 102 |
+
for i in range(steps):
|
| 103 |
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
| 104 |
chunks.append(wav)
|
| 105 |
+
if progress_cb:
|
| 106 |
+
progress_cb(i + 1, steps) # <-- report chunk progress
|
| 107 |
|
| 108 |
# Stitch continuous audio
|
| 109 |
stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
|