# magenta-retry / one_shot_generation.py
# Author: thecollabagepatch
# Last change: "more cleanup in one_shot_generation to avoid stale context" (commit dd42331)
"""
One-shot music generation functions for MagentaRT.
This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""
import math
import numpy as np
from magenta_rt import audio as au
from utils import (
match_loudness_to_reference,
stitch_generated,
hard_trim_seconds,
apply_micro_fades,
make_bar_aligned_context,
take_bar_aligned_tail
)
# Private attribute names that may hold per-session state on the codec and on
# the MagentaRT object. They are reset to None (never restored) before and after
# every loop-continuation generation so one request cannot see another's audio.
_CODEC_STATE_ATTRS = (
    '_encode_state', '_decode_state',
    '_last_encoded', '_last_decoded',
    '_encoder_cache', '_decoder_cache',
    '_buffer', '_frame_buffer',
)
_MRT_STATE_ATTRS = ('_last_state', '_generation_cache')


def _reset_generation_state(mrt):
    """Set each known state attribute on ``mrt.codec`` / ``mrt`` to None.

    Attributes the objects do not already have are skipped (nothing is created).
    """
    for attr in _CODEC_STATE_ATTRS:
        if hasattr(mrt.codec, attr):
            setattr(mrt.codec, attr, None)
    for attr in _MRT_STATE_ATTRS:
        if hasattr(mrt, attr):
            setattr(mrt, attr, None)


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
    progress_cb=None
):
    """
    Generate a continuation of an input loop using MagentaRT.

    A bar-aligned tail of the input loop (sized from the model's context
    length) is encoded as the model context and also embedded as a style;
    optional text styles are mixed in. Enough chunks are generated to cover
    ``bars`` bars plus the intro bars to drop, the intro is cut off, and the
    result is loudness-matched to the input loop.

    Args:
        mrt: MagentaRT system; this function uses its ``codec``, ``config``,
            ``sample_rate``, ``init_state``, ``embed_style`` and
            ``generate_chunk``.
        input_wav_path: Path of the loop to continue.
        bpm: Tempo in beats per minute.
        extra_styles: Optional iterable of text prompts; blank entries are
            skipped.
        style_weights: Optional weights aligned by index with
            ``extra_styles``; missing entries default to 1.0.
        bars: Number of bars in the returned audio.
        beats_per_bar: Beats per bar.
        loop_weight: Weight of the loop's own style embedding.
        loudness_mode: ``method`` passed to ``apply_barwise_loudness_match``.
        loudness_headroom_db: ``headroom_db`` passed to
            ``apply_barwise_loudness_match``.
        intro_bars_to_drop: Extra bars generated first and then discarded
            (capped at ``bars``).
        progress_cb: Optional ``callable(done, total)`` invoked once before
            generation and after every chunk.

    Returns:
        ``(out, loud_stats)``: the generated ``au.Waveform`` and the stats
        dict produced by ``apply_barwise_loudness_match``.
    """
    # Clear any per-session codec/model caches so audio from a previous
    # request cannot bleed into this one; the same reset runs again in
    # ``finally`` so this request cannot leak into the next.
    _reset_generation_state(mrt)
    try:
        # Load & prep. Re-wrap over a copied array so no sample buffer is
        # shared with whatever the loader/resampler may hold on to.
        loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
        loop = au.Waveform(loop.samples.copy(), loop.sample_rate)

        # Context audio: bar-aligned tail of the loop sized from the model
        # context length, again detached before encoding.
        codec_fps = float(mrt.codec.frame_rate)
        ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
        loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
        loop_for_context = au.Waveform(
            loop_for_context.samples.copy(),
            loop_for_context.sample_rate,
        )

        # Encode, keep only the RVQ depth the decoder consumes, and make sure
        # the tokens are an owned, C-ordered int32 array rather than a view.
        tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True)
        tokens = np.array(
            tokens_full[:, :mrt.config.decoder_codec_rvq_depth],
            dtype=np.int32, copy=True, order='C',
        )

        # Bar-aligned token window used as the model context.
        context_tokens = make_bar_aligned_context(
            tokens, bpm=bpm, fps=codec_fps,
            ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
        )
        context_tokens = np.ascontiguousarray(context_tokens, dtype=np.int32)

        state = mrt.init_state()
        state.context_tokens = context_tokens

        # Style: the loop's own embedding plus any non-blank text prompts,
        # combined with weights normalized to sum to 1 (unless they sum to 0).
        loop_embed = mrt.embed_style(loop_for_context)
        embeds = [np.array(loop_embed, copy=True)]
        weights = [float(loop_weight)]
        if extra_styles:
            for i, s in enumerate(extra_styles):
                prompt = s.strip()
                if not prompt:
                    continue
                embeds.append(np.array(mrt.embed_style(prompt), copy=True))
                # Index ``i`` is the position in ``extra_styles`` (blank entries
                # included), so weights stay aligned with the caller's list.
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))
        wsum = float(sum(weights)) or 1.0
        weights = [w / wsum for w in weights]
        combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0)
        combined_style = np.ascontiguousarray(combined_style, dtype=np.float32)

        # Length math: generate the requested bars plus the intro to drop;
        # the extra chunk gives the crossfade/trim some headroom.
        seconds_per_bar = beats_per_bar * (60.0 / bpm)
        total_secs = bars * seconds_per_bar
        drop_bars = max(0, int(intro_bars_to_drop))
        drop_secs = min(drop_bars, bars) * seconds_per_bar
        gen_total_secs = total_secs + drop_secs
        chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
        steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

        if progress_cb:
            progress_cb(0, steps)
        chunks = []
        for i in range(steps):
            wav, state = mrt.generate_chunk(state=state, style=combined_style)
            # Detach each chunk's samples from anything the model keeps around.
            chunks.append(au.Waveform(wav.samples.copy(), wav.sample_rate))
            if progress_cb:
                progress_cb(i + 1, steps)

        # Stitch, drop the intro, then trim to the exact musical length.
        stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
        stitched = hard_trim_seconds(stitched, gen_total_secs)
        if drop_secs > 0:
            n_drop = int(round(drop_secs * stitched.sample_rate))
            stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
        out = hard_trim_seconds(stitched, total_secs)

        out, loud_stats = apply_barwise_loudness_match(
            out=out,
            ref_loop=loop,
            bpm=bpm,
            beats_per_bar=beats_per_bar,
            method=loudness_mode,
            headroom_db=loudness_headroom_db,
            smooth_ms=50,
        )
        apply_micro_fades(out, 5)
        return out, loud_stats
    finally:
        _reset_generation_state(mrt)
def _parse_style_prompts(styles, style_weights):
    """Split comma-separated prompts and weights, pairing them by position.

    Weights are matched to prompts by their position in ``styles`` *before*
    blank prompts are dropped, so ``"rock,,jazz"`` with ``"1,2,3"`` yields
    rock->1.0 and jazz->3.0. Missing or blank weight entries default to 1.0.
    When no prompt is non-blank, returns ``(["warmup"], [1.0])``.
    """
    raw_weights = style_weights.split(",") if style_weights else []
    prompts, weights = [], []
    for i, raw in enumerate(styles.split(",") if styles else []):
        prompt = raw.strip()
        if not prompt:
            continue
        raw_w = raw_weights[i].strip() if i < len(raw_weights) else ""
        prompts.append(prompt)
        weights.append(float(raw_w) if raw_w else 1.0)
    if not prompts:
        return ["warmup"], [1.0]
    return prompts, weights


def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Args:
        mrt: MagentaRT system; uses its ``codec``, ``config``, ``sample_rate``,
            ``init_state``, ``embed_style`` and ``generate_chunk``.
        bpm: Tempo in beats per minute.
        bars: Number of bars in the returned audio.
        beats_per_bar: Beats per bar.
        styles: Comma-separated text prompts; blank entries are skipped and
            ``"warmup"`` is used when none remain.
        style_weights: Comma-separated weights aligned by position with
            ``styles``; missing or blank entries default to 1.0. Weights are
            normalized to sum to 1 (unless they sum to 0).
        intro_bars_to_drop: Extra bars generated first and then discarded
            (capped at ``bars``).

    Returns: (au.Waveform out, None). There is no reference loop, so no
    loudness stats are produced; ``out`` goes through ``peak_normalize(0.95)``
    instead.
    """
    # ---- Silent stereo context, one model context long, tokenized ----
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)
    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
    state = mrt.init_state()
    state.context_tokens = tokens

    # ---- Style vector (text prompts only, normalized weights) ----
    prompts, weights = _parse_style_prompts(styles, style_weights)
    embeds = [mrt.embed_style(p) for p in prompts]
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

    # ---- Target length math ----
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    # Chunk duration comes from the model config.
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    # Enough chunks to cover the total, plus one pad chunk for crossfade headroom.
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)

    # Stitch, drop the intro, and trim to the exact musical length.
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)
    return out, None  # loudness stats not applicable (no reference)
# loudness matching helper for /generate:
def apply_barwise_loudness_match(
    out: au.Waveform,
    ref_loop: au.Waveform,
    *,
    bpm: float,
    beats_per_bar: int,
    method: str = "auto",
    headroom_db: float = 1.0,
    smooth_ms: int = 50,
) -> tuple[au.Waveform, dict]:
    """
    Bar-locked loudness matching that establishes the correct starting level
    then maintains consistency. Only the first bar is matched to the reference;
    subsequent bars use the same gain to maintain relative dynamics.

    The first bar of ``out`` is matched against the first bar of ``ref_loop``
    with ``match_loudness_to_reference`` ("rms" is forced when that bar is
    shorter than 0.4 s). The resulting RMS gain ratio is then applied to every
    later bar.

    Args:
        out: Audio to adjust. Its ``samples`` attribute is replaced in place.
        ref_loop: Reference audio. It is resampled to ``out``'s rate if needed
            and converted to stereo.
        bpm: Tempo in beats per minute, used to compute the bar length.
        beats_per_bar: Beats per bar.
        method: Matching method passed to ``match_loudness_to_reference`` for
            the first bar.
        headroom_db: Headroom passed to ``match_loudness_to_reference``.
        smooth_ms: Length of the crossfade at each bar start (after the first)
            from the previous bar's gain into this bar's gain.

    Returns:
        ``(out, {"per_bar_gain_db": [...]})``. ``out`` is the same object that
        was passed in. The list holds one RMS-ratio gain in dB per processed
        bar.
    """
    sr = int(out.sample_rate)
    spb = (60.0 / float(bpm)) * int(beats_per_bar)
    bar_len = int(round(spb * sr))

    y = out.samples.astype(np.float32, copy=False)
    input_was_1d = y.ndim == 1
    if input_was_1d:
        y = y[:, None]

    if ref_loop.sample_rate != sr:
        ref = ref_loop.resample(sr).as_stereo().samples.astype(np.float32, copy=False)
    else:
        ref = ref_loop.as_stereo().samples.astype(np.float32, copy=False)
    if ref.ndim == 1:
        ref = ref[:, None]
    if ref.shape[1] == 1:
        ref = np.repeat(ref, 2, axis=1)

    # Reference level is taken from the first bar of the reference only.
    ref_bar_len = min(ref.shape[0], bar_len)
    ref_bar = au.Waveform(ref[:ref_bar_len], sr)

    eps = 1e-12
    gains_db = []
    out_adj = y.copy()
    need = y.shape[0]
    n_bars = max(1, int(np.ceil(need / float(bar_len))))
    ramp = int(max(0, round(smooth_ms * sr / 1000.0)))
    min_lufs_samples = int(0.4 * sr)

    first_bar_gain_linear = 1.0
    prev_gain_linear = 1.0  # gain applied to the previous bar (set after bar 0)
    for i in range(n_bars):
        s = i * bar_len
        e = min(need, s + bar_len)
        if e <= s:
            break
        bar_samples = e - s
        tgt_bar = au.Waveform(y[s:e], sr)  # always read from the ORIGINAL samples

        if i == 0:
            # First bar: match to the reference to establish the gain. Short
            # bars fall back to RMS matching.
            effective_method = "rms" if bar_samples < min_lufs_samples else method
            matched_bar, _ = match_loudness_to_reference(
                ref_bar, tgt_bar, method=effective_method, headroom_db=headroom_db
            )
            # Express what the matcher did as a single linear (RMS-ratio) gain.
            first_bar_gain_linear = float(np.sqrt(
                (np.mean(matched_bar.samples**2) + eps) /
                (np.mean(tgt_bar.samples**2) + eps)
            ))
            g = matched_bar.samples.astype(np.float32, copy=False)
        else:
            # Subsequent bars: reuse bar 0's gain to keep relative dynamics.
            g = (tgt_bar.samples * first_bar_gain_linear).astype(np.float32, copy=False)

        # Gain in dB for the stats.
        if tgt_bar.samples.size > 0:
            g_lin = float(np.sqrt((np.mean(g**2) + eps) / (np.mean(tgt_bar.samples**2) + eps)))
        else:
            g_lin = 1.0
        gains_db.append(20.0 * np.log10(max(g_lin, 1e-6)))

        if i > 0 and ramp > 0:
            # Fade from the gain the previous bar was played at into this bar's
            # gain. Fading from the unscaled input instead would pull the level
            # back toward unity at every downbeat.
            ramp_len = min(ramp, e - s)
            t = np.linspace(0.0, 1.0, ramp_len, dtype=np.float32)[:, None]
            faded_from = y[s:s + ramp_len] * prev_gain_linear
            out_adj[s:s + ramp_len] = (1.0 - t) * faded_from + t * g[:ramp_len]
            out_adj[s + ramp_len:e] = g[ramp_len:e - s]
        else:
            out_adj[s:e] = g
        prev_gain_linear = first_bar_gain_linear

    if input_was_1d:
        # Hand back samples with the same rank the caller passed in.
        out_adj = out_adj[:, 0]
    out.samples = out_adj.astype(np.float32, copy=False)
    return out, {"per_bar_gain_db": gains_db}