thecollabagepatch commited on
Commit
dd42331
·
1 Parent(s): 2446a8b

more cleanup in one_shot_generation to avoid stale context

Browse files
Files changed (1) hide show
  1. one_shot_generation.py +133 -107
one_shot_generation.py CHANGED
@@ -35,123 +35,149 @@ def generate_loop_continuation_with_mrt(
35
  Generate a continuation of an input loop using MagentaRT.
36
  """
37
 
38
- # ===== NEW: Force codec/model reset before generation =====
39
- # Clear any accumulated state in the codec that might cause silence issues
40
- try:
41
- # Option 1: If codec has explicit reset
42
- if hasattr(mrt.codec, 'reset') and callable(mrt.codec.reset):
43
- mrt.codec.reset()
44
-
45
- # Option 2: Force clear any cached codec state
46
- if hasattr(mrt.codec, '_encode_cache'):
47
- mrt.codec._encode_cache = None
48
- if hasattr(mrt.codec, '_decode_cache'):
49
- mrt.codec._decode_cache = None
50
-
51
- # Option 3: Clear JAX compilation caches (nuclear but effective)
52
- # Uncomment if issues persist:
53
- # import jax
54
- # jax.clear_caches()
55
-
56
- except Exception as e:
57
- import logging
58
- logging.warning(f"Codec reset attempt failed (non-fatal): {e}")
59
- # ============================================================
60
 
61
- # Load & prep (unchanged)
62
- loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
63
-
64
- # Use tail for context
65
- codec_fps = float(mrt.codec.frame_rate)
66
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
67
- loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
68
-
69
- # ===== NEW: Force fresh token copies =====
70
- tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True) # ← Added copy=True
71
- tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth].copy() # ← Added .copy()
72
- # ==========================================
73
-
74
- # Bar-aligned token window
75
- context_tokens = make_bar_aligned_context(
76
- tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
77
- ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
78
- )
79
 
80
- # ===== NEW: More aggressive state initialization =====
81
- state = mrt.init_state()
 
 
82
 
83
- # Ensure context_tokens is a fresh array, not a view
84
- state.context_tokens = np.array(context_tokens, dtype=np.int32, copy=True)
 
 
 
 
85
 
86
- # If there's any internal model state cache, clear it
87
- if hasattr(state, '_cache'):
88
- state._cache = None
89
- # =====================================================
90
-
91
- # STYLE embed (unchanged but ensure fresh embedding)
92
- loop_embed = mrt.embed_style(loop_for_context)
93
- embeds, weights = [loop_embed.copy()], [float(loop_weight)] # ← Added .copy()
94
- if extra_styles:
95
- for i, s in enumerate(extra_styles):
96
- if s.strip():
97
- embeds.append(mrt.embed_style(s.strip()).copy()) # ← Added .copy()
98
- w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
99
- weights.append(float(w))
100
- wsum = float(sum(weights)) or 1.0
101
- weights = [w / wsum for w in weights]
102
- combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype, copy=True) # ← Added copy=True
103
-
104
- # --- Length math (unchanged) ---
105
- seconds_per_bar = beats_per_bar * (60.0 / bpm)
106
- total_secs = bars * seconds_per_bar
107
- drop_bars = max(0, int(intro_bars_to_drop))
108
- drop_secs = min(drop_bars, bars) * seconds_per_bar
109
- gen_total_secs = total_secs + drop_secs
110
-
111
- chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
112
- steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
113
-
114
- if progress_cb:
115
- progress_cb(0, steps)
116
-
117
- # ===== NEW: Generation loop with explicit state refresh =====
118
- chunks = []
119
- for i in range(steps):
120
- # Generate chunk with current state
121
- wav, new_state = mrt.generate_chunk(state=state, style=combined_style)
122
- chunks.append(wav)
123
 
124
- # CRITICAL: Replace state, don't mutate it
125
- # This ensures we're not accumulating corrupted state
126
- state = new_state
127
 
128
- if progress_cb:
129
- progress_cb(i + 1, steps)
130
- # ============================================================
131
-
132
- # Rest of the function unchanged...
133
- stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
134
- stitched = hard_trim_seconds(stitched, gen_total_secs)
135
-
136
- if drop_secs > 0:
137
- n_drop = int(round(drop_secs * stitched.sample_rate))
138
- stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- out = hard_trim_seconds(stitched, total_secs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- out, loud_stats = apply_barwise_loudness_match(
143
- out=out,
144
- ref_loop=loop,
145
- bpm=bpm,
146
- beats_per_bar=beats_per_bar,
147
- method=loudness_mode,
148
- headroom_db=loudness_headroom_db,
149
- smooth_ms=50,
150
- )
151
 
152
- apply_micro_fades(out, 5)
 
153
 
154
- return out, loud_stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
 
157
  def generate_style_only_with_mrt(
 
35
  Generate a continuation of an input loop using MagentaRT.
36
  """
37
 
38
+ # ===== CRITICAL FIX: Force codec state isolation =====
39
+ # Create a completely isolated encoding session to prevent
40
+ # audio from previous generations bleeding into this one
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Save original codec state (if any)
43
+ original_codec_state = {}
44
+ codec_attrs_to_clear = [
45
+ '_encode_state', '_decode_state',
46
+ '_last_encoded', '_last_decoded',
47
+ '_encoder_cache', '_decoder_cache',
48
+ '_buffer', '_frame_buffer'
49
+ ]
 
 
 
 
 
 
 
 
 
 
50
 
51
+ for attr in codec_attrs_to_clear:
52
+ if hasattr(mrt.codec, attr):
53
+ original_codec_state[attr] = getattr(mrt.codec, attr)
54
+ setattr(mrt.codec, attr, None)
55
 
56
+ # Also clear any MRT-level generation state
57
+ mrt_attrs_to_clear = ['_last_state', '_generation_cache']
58
+ for attr in mrt_attrs_to_clear:
59
+ if hasattr(mrt, attr):
60
+ original_codec_state[f'mrt_{attr}'] = getattr(mrt, attr)
61
+ setattr(mrt, attr, None)
62
 
63
+ try:
64
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ # Load & prep - Force FRESH file read (no caching)
67
+ loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
 
68
 
69
+ # CRITICAL: Create a detached copy to prevent reference issues
70
+ loop = au.Waveform(
71
+ loop.samples.copy(), # Force array copy
72
+ loop.sample_rate
73
+ )
74
+
75
+ # Use tail for context
76
+ codec_fps = float(mrt.codec.frame_rate)
77
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
78
+ loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
79
+
80
+ # CRITICAL: Another detached copy before encoding
81
+ loop_for_context = au.Waveform(
82
+ loop_for_context.samples.copy(),
83
+ loop_for_context.sample_rate
84
+ )
85
+
86
+ # Force fresh encoding with explicit copy flags
87
+ tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True)
88
+ tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
89
+
90
+ # CRITICAL: Ensure tokens are not a view
91
+ tokens = np.array(tokens, dtype=np.int32, copy=True, order='C')
92
+
93
+ # Bar-aligned token window
94
+ context_tokens = make_bar_aligned_context(
95
+ tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
96
+ ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
97
+ )
98
+
99
+ # CRITICAL: Force contiguous memory layout
100
+ context_tokens = np.ascontiguousarray(context_tokens, dtype=np.int32)
101
+
102
+ # Create completely fresh state
103
+ state = mrt.init_state()
104
+ state.context_tokens = context_tokens
105
 
106
+ # STYLE embed - force fresh
107
+ loop_embed = mrt.embed_style(loop_for_context)
108
+ embeds, weights = [np.array(loop_embed, copy=True)], [float(loop_weight)]
109
+
110
+ if extra_styles:
111
+ for i, s in enumerate(extra_styles):
112
+ if s.strip():
113
+ e = mrt.embed_style(s.strip())
114
+ embeds.append(np.array(e, copy=True))
115
+ w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
116
+ weights.append(float(w))
117
+
118
+ wsum = float(sum(weights)) or 1.0
119
+ weights = [w / wsum for w in weights]
120
+ combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0)
121
+ combined_style = np.ascontiguousarray(combined_style, dtype=np.float32)
122
 
123
+ # --- Length math ---
124
+ seconds_per_bar = beats_per_bar * (60.0 / bpm)
125
+ total_secs = bars * seconds_per_bar
126
+ drop_bars = max(0, int(intro_bars_to_drop))
127
+ drop_secs = min(drop_bars, bars) * seconds_per_bar
128
+ gen_total_secs = total_secs + drop_secs
 
 
 
129
 
130
+ chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
131
+ steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
132
 
133
+ if progress_cb:
134
+ progress_cb(0, steps)
135
+
136
+ # Generate with state isolation
137
+ chunks = []
138
+ for i in range(steps):
139
+ wav, state = mrt.generate_chunk(state=state, style=combined_style)
140
+ # Force copy the waveform samples to prevent reference issues
141
+ wav = au.Waveform(wav.samples.copy(), wav.sample_rate)
142
+ chunks.append(wav)
143
+ if progress_cb:
144
+ progress_cb(i + 1, steps)
145
+
146
+ # Rest unchanged...
147
+ stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
148
+ stitched = hard_trim_seconds(stitched, gen_total_secs)
149
+
150
+ if drop_secs > 0:
151
+ n_drop = int(round(drop_secs * stitched.sample_rate))
152
+ stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
153
+
154
+ out = hard_trim_seconds(stitched, total_secs)
155
+
156
+ out, loud_stats = apply_barwise_loudness_match(
157
+ out=out,
158
+ ref_loop=loop,
159
+ bpm=bpm,
160
+ beats_per_bar=beats_per_bar,
161
+ method=loudness_mode,
162
+ headroom_db=loudness_headroom_db,
163
+ smooth_ms=50,
164
+ )
165
+
166
+ apply_micro_fades(out, 5)
167
+
168
+ return out, loud_stats
169
+
170
+ finally:
171
+ # ===== CLEANUP: Clear codec state after generation =====
172
+ # This prevents audio from THIS generation leaking into the NEXT one
173
+ for attr in codec_attrs_to_clear:
174
+ if hasattr(mrt.codec, attr):
175
+ setattr(mrt.codec, attr, None)
176
+
177
+ for attr in mrt_attrs_to_clear:
178
+ if hasattr(mrt, attr):
179
+ setattr(mrt, attr, None)
180
+ # =======================================================
181
 
182
 
183
  def generate_style_only_with_mrt(