"""
One-shot music generation functions for MagentaRT.

This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""

import math

import numpy as np

from magenta_rt import audio as au
from utils import (
    match_loudness_to_reference,
    stitch_generated,
    hard_trim_seconds,
    apply_micro_fades,
    make_bar_aligned_context,
    take_bar_aligned_tail
)
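
# Example usage (a sketch, not a definitive recipe): these functions assume an
# already-constructed MagentaRT-style object exposing `codec`, `config`,
# `sample_rate`, `init_state`, `embed_style`, and `generate_chunk`, as used
# below. The constructor shown here is an assumption; adapt it to however your
# app builds the model.
#
#     from magenta_rt import system  # assumed import path
#     mrt = system.MagentaRT()
#     out, stats = generate_loop_continuation_with_mrt(
#         mrt, "input_loop.wav", bpm=120.0, bars=8,
#         extra_styles=["warm analog synth"], style_weights=[1.0],
#     )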


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
    progress_cb=None
):
    """
    Generate a bar-aligned continuation of an input loop using MagentaRT.

    The loop supplies both the codec context (its bar-aligned tail) and a style
    embedding, which can be blended with optional extra text styles.

    Returns: (au.Waveform out, dict loud_stats)
    """
    original_codec_state = {}
    codec_attrs_to_clear = [
        '_encode_state', '_decode_state',
        '_last_encoded', '_last_decoded',
        '_encoder_cache', '_decoder_cache',
        '_buffer', '_frame_buffer'
    ]

    for attr in codec_attrs_to_clear:
        if hasattr(mrt.codec, attr):
            original_codec_state[attr] = getattr(mrt.codec, attr)
            setattr(mrt.codec, attr, None)

    mrt_attrs_to_clear = ['_last_state', '_generation_cache']
    for attr in mrt_attrs_to_clear:
        if hasattr(mrt, attr):
            original_codec_state[f'mrt_{attr}'] = getattr(mrt, attr)
            setattr(mrt, attr, None)

    try:
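        # Load the input loop, resample it to the model sample rate, force
        # stereo, and copy the sample buffer so later processing cannot alias
        # the loaded waveform's internal arrays.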
        loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
        loop = au.Waveform(
            loop.samples.copy(),
            loop.sample_rate
        )
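
        # The model context window is context_length_frames at the codec frame
        # rate; take a bar-aligned tail of the loop to fill that window.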
        codec_fps = float(mrt.codec.frame_rate)
        ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
        loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

        loop_for_context = au.Waveform(
            loop_for_context.samples.copy(),
            loop_for_context.sample_rate
        )
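
        # Encode the context audio to codec tokens, keep only the RVQ depth the
        # decoder consumes, and build a bar-aligned token context of
        # context_length_frames frames (see make_bar_aligned_context in utils).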
        tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True)
        tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
        tokens = np.array(tokens, dtype=np.int32, copy=True, order='C')

        context_tokens = make_bar_aligned_context(
            tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
            ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
        )
        context_tokens = np.ascontiguousarray(context_tokens, dtype=np.int32)
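
        # Seed the generation state with the bar-aligned context, then build
        # the style vector: the loop's own style embedding blended with any
        # extra text styles, with weights normalised to sum to 1.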
        state = mrt.init_state()
        state.context_tokens = context_tokens

        loop_embed = mrt.embed_style(loop_for_context)
        embeds, weights = [np.array(loop_embed, copy=True)], [float(loop_weight)]

        if extra_styles:
            for i, s in enumerate(extra_styles):
                if s.strip():
                    e = mrt.embed_style(s.strip())
                    embeds.append(np.array(e, copy=True))
                    w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                    weights.append(float(w))

        wsum = float(sum(weights)) or 1.0
        weights = [w / wsum for w in weights]
        combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0)
        combined_style = np.ascontiguousarray(combined_style, dtype=np.float32)
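
        # Length bookkeeping: convert bars to seconds, optionally generate a
        # few extra intro bars that will be dropped later, and generate one
        # chunk beyond the target length so the final hard trim has material.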
        seconds_per_bar = beats_per_bar * (60.0 / bpm)
        total_secs = bars * seconds_per_bar
        drop_bars = max(0, int(intro_bars_to_drop))
        drop_secs = min(drop_bars, bars) * seconds_per_bar
        gen_total_secs = total_secs + drop_secs

        chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
        steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

        if progress_cb:
            progress_cb(0, steps)
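
        # Generate chunk by chunk; each call returns roughly chunk_secs of
        # audio plus the updated state. Samples are copied so the stitching
        # step below owns its buffers.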
        chunks = []
        for i in range(steps):
            wav, state = mrt.generate_chunk(state=state, style=combined_style)
            wav = au.Waveform(wav.samples.copy(), wav.sample_rate)
            chunks.append(wav)
            if progress_cb:
                progress_cb(i + 1, steps)
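
        # Crossfade-stitch the chunks, trim to the generated length, drop the
        # optional intro bars, then trim to the requested number of bars.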
        stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
        stitched = hard_trim_seconds(stitched, gen_total_secs)

        if drop_secs > 0:
            n_drop = int(round(drop_secs * stitched.sample_rate))
            stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

        out = hard_trim_seconds(stitched, total_secs)
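
        # Match loudness bar-by-bar against the original loop (see
        # apply_barwise_loudness_match below), then apply micro fades to avoid
        # clicks at the edges.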
        out, loud_stats = apply_barwise_loudness_match(
            out=out,
            ref_loop=loop,
            bpm=bpm,
            beats_per_bar=beats_per_bar,
            method=loudness_mode,
            headroom_db=loudness_headroom_db,
            smooth_ms=50,
        )

        apply_micro_fades(out, 5)

        return out, loud_stats

    finally:
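        # Clear the probed cache attributes again on the way out so no
        # per-request state lingers on the shared codec / model objects (note
        # that the snapshot in original_codec_state is not restored here).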
        for attr in codec_attrs_to_clear:
            if hasattr(mrt.codec, attr):
                setattr(mrt.codec, attr, None)

        for attr in mrt_attrs_to_clear:
            if hasattr(mrt, attr):
                setattr(mrt, attr, None)


def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).
    Returns: (au.Waveform out, dict loud_stats_or_None)
    """
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)

    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    state = mrt.init_state()
    state.context_tokens = tokens
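
    # Parse the comma-separated style prompts and weights, embed each prompt,
    # and blend the embeddings with weights normalised to sum to 1 (falling
    # back to the 'warmup' prompt when none are given).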
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]
    sw = [float(x) for x in style_weights.split(",")] if style_weights else []
    embeds, weights = [], []
    for i, p in enumerate(prompts):
        embeds.append(mrt.embed_style(p))
        weights.append(sw[i] if i < len(sw) else 1.0)
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
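
    # Same length bookkeeping as the loop-continuation path: bars to seconds,
    # optional intro bars to drop, and one extra chunk of headroom before the
    # final trim.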
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)
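
    # Crossfade-stitch, trim to the generated length, drop the optional intro
    # bars, trim to the requested bars, then peak-normalise and micro-fade.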
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    return out, None


def apply_barwise_loudness_match(
    out: au.Waveform,
    ref_loop: au.Waveform,
    *,
    bpm: float,
    beats_per_bar: int,
    method: str = "auto",
    headroom_db: float = 1.0,
    smooth_ms: int = 50,
) -> tuple[au.Waveform, dict]:
    """
    Bar-locked loudness matching that establishes the correct starting level
    and then keeps it consistent: only the first bar is matched to the
    reference; subsequent bars reuse the same gain so the generation keeps its
    own relative dynamics.
    """
    sr = int(out.sample_rate)
    spb = (60.0 / float(bpm)) * int(beats_per_bar)
    bar_len = int(round(spb * sr))

    y = out.samples.astype(np.float32, copy=False)
    if y.ndim == 1:
        y = y[:, None]

    if ref_loop.sample_rate != sr:
        ref = ref_loop.resample(sr).as_stereo().samples.astype(np.float32, copy=False)
    else:
        ref = ref_loop.as_stereo().samples.astype(np.float32, copy=False)

    if ref.ndim == 1:
        ref = ref[:, None]
    if ref.shape[1] == 1:
        ref = np.repeat(ref, 2, axis=1)
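
    # Only the first bar is loudness-matched against (at most) the first bar
    # of the reference; the linear gain implied by that match is then reused
    # for every later bar. Per-bar gains are logged in dB for the caller, and
    # a short linear ramp smooths each bar boundary.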
    ref_bar_len = min(ref.shape[0], bar_len)
    ref_bar = au.Waveform(ref[:ref_bar_len], sr)

    gains_db = []
    out_adj = y.copy()
    need = y.shape[0]
    n_bars = max(1, int(np.ceil(need / float(bar_len))))
    ramp = int(max(0, round(smooth_ms * sr / 1000.0)))
    # Bars shorter than ~400 ms fall back to RMS matching (too short for a
    # reliable LUFS estimate).
    min_lufs_samples = int(0.4 * sr)

    first_bar_gain_linear = 1.0

    for i in range(n_bars):
        s = i * bar_len
        e = min(need, s + bar_len)
        if e <= s:
            break

        bar_samples = e - s
        tgt_bar = au.Waveform(y[s:e], sr)

        if i == 0:
            # Match the first bar to the reference, then recover the linear
            # gain it implied (RMS ratio of matched vs. unmatched samples).
            effective_method = "rms" if bar_samples < min_lufs_samples else method
            matched_bar, stats = match_loudness_to_reference(
                ref_bar, tgt_bar, method=effective_method, headroom_db=headroom_db
            )

            eps = 1e-12
            first_bar_gain_linear = float(np.sqrt(
                (np.mean(matched_bar.samples**2) + eps) /
                (np.mean(tgt_bar.samples**2) + eps)
            ))
            g = matched_bar.samples.astype(np.float32, copy=False)
        else:
            # Later bars reuse the first bar's gain unchanged.
            g = (tgt_bar.samples * first_bar_gain_linear).astype(np.float32, copy=False)

        # Log the effective per-bar gain in dB for the returned stats.
        if tgt_bar.samples.size > 0:
            eps = 1e-12
            g_lin = float(np.sqrt((np.mean(g**2) + eps) / (np.mean(tgt_bar.samples**2) + eps)))
        else:
            g_lin = 1.0
        gains_db.append(20.0 * np.log10(max(g_lin, 1e-6)))

        # Crossfade into each new bar over smooth_ms to avoid audible gain steps.
        if i > 0 and ramp > 0:
            ramp_len = min(ramp, e - s)
            t = np.linspace(0.0, 1.0, ramp_len, dtype=np.float32)[:, None]
            out_adj[s:s + ramp_len] = (1.0 - t) * out_adj[s:s + ramp_len] + t * g[:ramp_len]
            out_adj[s + ramp_len:e] = g[ramp_len:e - s]
        else:
            out_adj[s:e] = g

    out.samples = out_adj.astype(np.float32, copy=False)
    return out, {"per_bar_gain_db": gains_db}