File size: 12,680 Bytes
842a99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba62a7
842a99f
 
 
2446a8b
842a99f
dd42331
 
 
2446a8b
dd42331
 
 
 
 
 
 
 
2446a8b
dd42331
 
 
 
2446a8b
dd42331
 
 
 
 
 
2446a8b
dd42331
 
2446a8b
dd42331
 
2446a8b
dd42331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842a99f
dd42331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842a99f
dd42331
 
 
 
 
 
842a99f
dd42331
 
7457794
dd42331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842a99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78cac08
 
 
 
 
 
 
 
 
 
 
 
 
6d5b723
78cac08
 
6d5b723
 
7ae8a62
78cac08
 
 
 
 
 
 
6d5b723
78cac08
 
 
 
 
 
 
 
6d5b723
78cac08
7ae8a62
6d5b723
 
 
78cac08
 
6d5b723
78cac08
 
e87e83d
7ae8a62
 
 
 
78cac08
 
 
e87e83d
 
 
 
7ae8a62
78cac08
7ae8a62
6d5b723
 
 
 
 
7ae8a62
 
 
 
 
 
 
 
e87e83d
7ae8a62
 
e87e83d
7ae8a62
78cac08
 
 
 
 
 
 
6d5b723
78cac08
6d5b723
1355fb6
 
 
78cac08
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""
One-shot music generation functions for MagentaRT.

This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""
import math
import numpy as np
from magenta_rt import audio as au
from utils import (
    match_loudness_to_reference, 
    stitch_generated, 
    hard_trim_seconds,
    apply_micro_fades, 
    make_bar_aligned_context, 
    take_bar_aligned_tail
)


def _clear_mrt_cached_state(mrt, codec_attrs, mrt_attrs):
    """Set every listed attribute that exists on ``mrt.codec`` / ``mrt`` to None."""
    for attr in codec_attrs:
        if hasattr(mrt.codec, attr):
            setattr(mrt.codec, attr, None)
    for attr in mrt_attrs:
        if hasattr(mrt, attr):
            setattr(mrt, attr, None)


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
    progress_cb=None
):
    """
    Generate a continuation of an input loop using MagentaRT.

    The tail of the loop (bar-aligned, one model context long) is encoded to
    codec tokens and used as the generation context; the same audio is also
    embedded as a style and blended with any text styles. Chunks are generated
    until the requested number of bars (plus any intro bars to drop) is
    covered, then stitched, trimmed, loudness-matched to the input loop and
    micro-faded.

    Side effect: a fixed list of private cache/state attributes on
    ``mrt.codec`` and ``mrt`` is set to None (where present) both before and
    after generation. The previous values are NOT restored.

    Args:
        mrt: MagentaRT model object. Must expose ``codec``, ``config``,
            ``sample_rate``, ``init_state``, ``embed_style`` and
            ``generate_chunk``.
        input_wav_path: Path of the audio file to continue.
        bpm: Tempo used for all bar-length math.
        extra_styles: Optional iterable of text prompts; entries that are
            blank after stripping are skipped.
        style_weights: Optional weights indexed in parallel with
            ``extra_styles``; missing entries default to 1.0.
        bars: Number of bars to return.
        beats_per_bar: Beats in one bar.
        loop_weight: Style weight of the input loop's own embedding.
        loudness_mode: Passed as ``method`` to ``apply_barwise_loudness_match``.
        loudness_headroom_db: Passed as ``headroom_db`` to the same helper.
        intro_bars_to_drop: Extra bars generated and then discarded from the
            start of the output (clamped to ``[0, bars]``).
        progress_cb: Optional callable invoked as ``progress_cb(done, total)``
            once with ``done == 0`` and then after every generated chunk.

    Returns:
        ``(out, loud_stats)`` as returned by ``apply_barwise_loudness_match``.
    """
    # Private attributes that may hold audio/tokens from an earlier generation.
    # They are only touched when they actually exist on the object.
    codec_attrs_to_clear = (
        '_encode_state', '_decode_state',
        '_last_encoded', '_last_decoded',
        '_encoder_cache', '_decoder_cache',
        '_buffer', '_frame_buffer',
    )
    mrt_attrs_to_clear = ('_last_state', '_generation_cache')

    # Drop anything left over from a previous call so it cannot bleed into
    # this generation. Nothing is saved: keeping references to the old values
    # would serve no purpose because they are never restored.
    _clear_mrt_cached_state(mrt, codec_attrs_to_clear, mrt_attrs_to_clear)

    try:
        # Load & prep. The samples are copied into a fresh Waveform so later
        # processing never aliases arrays owned by the loader.
        loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
        loop = au.Waveform(loop.samples.copy(), loop.sample_rate)

        # Use the bar-aligned tail of the loop, one model context long.
        codec_fps = float(mrt.codec.frame_rate)
        ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
        loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
        loop_for_context = au.Waveform(
            loop_for_context.samples.copy(),
            loop_for_context.sample_rate
        )

        # Encode, keep only the RVQ levels the decoder consumes, and make sure
        # the result is an owned C-ordered int32 array rather than a view.
        tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True)
        tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
        tokens = np.array(tokens, dtype=np.int32, copy=True, order='C')

        # Bar-aligned token window
        context_tokens = make_bar_aligned_context(
            tokens, bpm=bpm, fps=codec_fps,
            ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
        )
        context_tokens = np.ascontiguousarray(context_tokens, dtype=np.int32)

        # Fresh generation state seeded with the loop context
        state = mrt.init_state()
        state.context_tokens = context_tokens

        # Style: the loop's own embedding plus optional text prompts.
        loop_embed = mrt.embed_style(loop_for_context)
        embeds, weights = [np.array(loop_embed, copy=True)], [float(loop_weight)]

        if extra_styles:
            for i, s in enumerate(extra_styles):
                prompt = s.strip()
                if not prompt:
                    continue
                embeds.append(np.array(mrt.embed_style(prompt), copy=True))
                # Weights are looked up by the prompt's original position, so
                # skipped blank prompts do not shift the pairing.
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))

        # Normalize weights; a zero sum falls back to 1.0 to avoid dividing by 0.
        wsum = float(sum(weights)) or 1.0
        weights = [w / wsum for w in weights]
        combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0)
        combined_style = np.ascontiguousarray(combined_style, dtype=np.float32)

        # --- Length math ---
        seconds_per_bar = beats_per_bar * (60.0 / bpm)
        total_secs      = bars * seconds_per_bar
        drop_bars       = max(0, int(intro_bars_to_drop))
        drop_secs       = min(drop_bars, bars) * seconds_per_bar
        gen_total_secs  = total_secs + drop_secs

        chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
        # One chunk more than strictly needed to cover the target length.
        steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

        if progress_cb:
            progress_cb(0, steps)

        chunks = []
        for i in range(steps):
            wav, state = mrt.generate_chunk(state=state, style=combined_style)
            # Copy the samples so stored chunks do not alias model-owned buffers.
            chunks.append(au.Waveform(wav.samples.copy(), wav.sample_rate))
            if progress_cb:
                progress_cb(i + 1, steps)

        stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
        stitched = hard_trim_seconds(stitched, gen_total_secs)

        # Discard the requested intro bars from the front.
        if drop_secs > 0:
            n_drop = int(round(drop_secs * stitched.sample_rate))
            stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

        out = hard_trim_seconds(stitched, total_secs)

        out, loud_stats = apply_barwise_loudness_match(
            out=out,
            ref_loop=loop,
            bpm=bpm,
            beats_per_bar=beats_per_bar,
            method=loudness_mode,
            headroom_db=loudness_headroom_db,
            smooth_ms=50,
        )

        # Return value is ignored — presumably this fades `out` in place
        # (TODO confirm against utils.apply_micro_fades).
        apply_micro_fades(out, 5)

        return out, loud_stats

    finally:
        # Clear again so audio from THIS generation cannot leak into the next
        # one, even when generation failed part-way.
        _clear_mrt_cached_state(mrt, codec_attrs_to_clear, mrt_attrs_to_clear)


def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Args:
        mrt: MagentaRT model object. Must expose ``codec``, ``config``,
            ``sample_rate``, ``init_state``, ``embed_style`` and
            ``generate_chunk``.
        bpm: Tempo used for all bar-length math.
        bars: Number of bars to return.
        beats_per_bar: Beats in one bar.
        styles: Comma-separated text prompts. Blank entries are ignored; if
            none remain, the single prompt ``"warmup"`` is used.
        style_weights: Comma-separated weights paired with the prompts by
            position. Missing or blank entries default to 1.0. Weights are
            normalized to sum to 1.
        intro_bars_to_drop: Extra bars generated and then discarded from the
            start of the output (clamped to ``[0, bars]``).

    Returns: (au.Waveform out, dict loud_stats_or_None)
        ``loud_stats`` is always None here because there is no reference audio.

    Raises:
        ValueError: If a non-blank ``style_weights`` entry is not a number.
    """
    # ---- Build a silent context one model-context long, tokenized ----
    codec_fps   = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr          = int(mrt.sample_rate)

    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    # Keep only the RVQ levels the decoder consumes.
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    state = mrt.init_state()
    state.context_tokens = tokens

    # ---- Style vector (text prompts only, normalized weights) ----
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]

    # A blank entry (e.g. from a trailing comma or "1,,2") takes the default
    # weight instead of crashing in float("").
    sw = (
        [float(x) if x.strip() else 1.0 for x in style_weights.split(",")]
        if style_weights
        else []
    )

    embeds, weights = [], []
    for i, p in enumerate(prompts):
        embeds.append(mrt.embed_style(p))
        weights.append(sw[i] if i < len(sw) else 1.0)
    # A zero sum falls back to 1.0 to avoid dividing by zero.
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

    # ---- Target length math ----
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs      = bars * seconds_per_bar
    drop_bars       = max(0, int(intro_bars_to_drop))
    drop_secs       = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs  = total_secs + drop_secs

    # Chunk duration derived from the model config
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)

    # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)

    # Stitch & trim to exact musical length
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Discard the requested intro bars from the front.
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    # Return value is ignored — presumably this fades `out` in place
    # (TODO confirm against utils.apply_micro_fades).
    apply_micro_fades(out, 5)

    return out, None  # loudness stats not applicable (no reference)


# loudness matching helper for /generate:

def apply_barwise_loudness_match(
    out: au.Waveform,
    ref_loop: au.Waveform,
    *,
    bpm: float,
    beats_per_bar: int,
    method: str = "auto",
    headroom_db: float = 1.0,
    smooth_ms: int = 50,
) -> tuple[au.Waveform, dict]:
    """
    Bar-locked loudness matching that establishes the correct starting level
    then maintains consistency. Only the first bar is matched to the reference;
    subsequent bars use the same gain to maintain relative dynamics.

    Args:
        out: Generated audio to adjust. Its ``samples`` attribute is replaced
            in place and the same object is returned.
        ref_loop: Reference audio; its first bar (or all of it, if shorter
            than a bar) defines the target loudness. Resampled to ``out``'s
            sample rate when the rates differ.
        bpm: Tempo used to compute the bar length.
        beats_per_bar: Beats in one bar.
        method: Passed to ``match_loudness_to_reference`` for the first bar.
            Overridden with ``"rms"`` when that bar is shorter than 0.4 s
            (presumably too short for a LUFS measurement — TODO confirm).
        headroom_db: Passed to ``match_loudness_to_reference``.
        smooth_ms: Accepted for backward compatibility and currently unused.
            Every bar after the first receives the same linear gain, so there
            is no gain step at a bar line that would need smoothing.

    Returns:
        ``(out, {"per_bar_gain_db": [...]})`` with one measured RMS gain in dB
        per processed bar.
    """
    from utils import match_loudness_to_reference

    eps = 1e-12  # keeps the RMS ratios finite for silent bars

    sr = int(out.sample_rate)
    spb = (60.0 / float(bpm)) * int(beats_per_bar)
    # At least one sample per bar so the bar count below never divides by zero.
    bar_len = max(1, int(round(spb * sr)))

    y = out.samples.astype(np.float32, copy=False)
    if y.ndim == 1:
        y = y[:, None]

    if ref_loop.sample_rate != sr:
        ref = ref_loop.resample(sr).as_stereo().samples.astype(np.float32, copy=False)
    else:
        ref = ref_loop.as_stereo().samples.astype(np.float32, copy=False)

    if ref.ndim == 1:
        ref = ref[:, None]
    if ref.shape[1] == 1:
        ref = np.repeat(ref, 2, axis=1)

    # Reference loudness is taken from (at most) the first bar of the loop.
    ref_bar_len = min(ref.shape[0], bar_len)
    ref_bar = au.Waveform(ref[:ref_bar_len], sr)

    gains_db = []
    out_adj = y.copy()
    need = y.shape[0]
    n_bars = max(1, int(np.ceil(need / float(bar_len))))
    min_lufs_samples = int(0.4 * sr)

    # Linear gain established by matching bar 0; reused for all later bars.
    first_bar_gain_linear = 1.0

    for i in range(n_bars):
        s = i * bar_len
        e = min(need, s + bar_len)
        if e <= s:
            break

        tgt_bar = au.Waveform(y[s:e], sr)  # Always read from ORIGINAL

        if i == 0:
            # First bar: match to reference to establish gain
            effective_method = "rms" if (e - s) < min_lufs_samples else method
            matched_bar, _ = match_loudness_to_reference(
                ref_bar, tgt_bar, method=effective_method, headroom_db=headroom_db
            )

            # Recover the linear gain that was applied from the RMS ratio.
            first_bar_gain_linear = float(np.sqrt(
                (np.mean(matched_bar.samples**2) + eps) /
                (np.mean(tgt_bar.samples**2) + eps)
            ))
            g = matched_bar.samples.astype(np.float32, copy=False)
        else:
            # Subsequent bars: apply the same gain from bar 0
            g = (tgt_bar.samples * first_bar_gain_linear).astype(np.float32, copy=False)

        # Measured gain in dB for stats
        if tgt_bar.samples.size > 0:
            g_lin = float(np.sqrt((np.mean(g**2) + eps) / (np.mean(tgt_bar.samples**2) + eps)))
        else:
            g_lin = 1.0
        gains_db.append(20.0 * np.log10(max(g_lin, 1e-6)))

        # Write the gained bar directly. The earlier implementation cross-faded
        # from the still-unadjusted samples at the start of every bar, which
        # reset the level to unity gain at each bar line and then ramped back,
        # producing a gain jump and audible pumping.
        out_adj[s:e] = g

    out.samples = out_adj.astype(np.float32, copy=False)
    return out, {"per_bar_gain_db": gains_db}