Commit
·
7fe8be5
1
Parent(s):
4bdf506
ok reverting one more time
Browse files- jam_worker.py +355 -404
- utils.py +36 -62
jam_worker.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
# jam_worker.py -
|
| 2 |
-
import threading, time, base64, io, uuid
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
import numpy as np
|
| 5 |
import soundfile as sf
|
|
@@ -8,7 +8,7 @@ from threading import RLock
|
|
| 8 |
from utils import (
|
| 9 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
| 10 |
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
|
| 11 |
-
resample_and_snap, wav_bytes_base64
|
| 12 |
)
|
| 13 |
|
| 14 |
@dataclass
|
|
@@ -32,34 +32,6 @@ class JamChunk:
|
|
| 32 |
audio_base64: str
|
| 33 |
metadata: dict
|
| 34 |
|
| 35 |
-
@dataclass
|
| 36 |
-
class TimingState:
|
| 37 |
-
"""Precise timing state tracking"""
|
| 38 |
-
# Fractional bar position (never rounded until final emission)
|
| 39 |
-
emit_position_bars: float = 0.0
|
| 40 |
-
|
| 41 |
-
# Sample-accurate positions in the stream
|
| 42 |
-
stream_position_samples: int = 0
|
| 43 |
-
|
| 44 |
-
# Accumulated timing error for correction
|
| 45 |
-
fractional_error_bars: float = 0.0
|
| 46 |
-
|
| 47 |
-
# Codec frame timing
|
| 48 |
-
frames_per_bar: float = 0.0
|
| 49 |
-
samples_per_bar: float = 0.0
|
| 50 |
-
|
| 51 |
-
def advance_by_bars(self, bars: float):
|
| 52 |
-
"""Advance timing by exact fractional bars"""
|
| 53 |
-
self.emit_position_bars += bars
|
| 54 |
-
self.fractional_error_bars += bars - int(bars)
|
| 55 |
-
|
| 56 |
-
# Correct for accumulated error when it gets significant
|
| 57 |
-
if abs(self.fractional_error_bars) > 0.5:
|
| 58 |
-
correction = int(round(self.fractional_error_bars))
|
| 59 |
-
self.fractional_error_bars -= correction
|
| 60 |
-
return correction # bars to skip/rewind
|
| 61 |
-
return 0
|
| 62 |
-
|
| 63 |
class JamWorker(threading.Thread):
|
| 64 |
def __init__(self, mrt, params: JamParams):
|
| 65 |
super().__init__(daemon=True)
|
|
@@ -67,32 +39,9 @@ class JamWorker(threading.Thread):
|
|
| 67 |
self.params = params
|
| 68 |
self.state = mrt.init_state()
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
self._codec_fps = float(self.mrt.codec.frame_rate) # 25.0
|
| 72 |
-
self._model_sr = int(self.mrt.sample_rate) # 48000
|
| 73 |
-
self._target_sr = int(params.target_sr)
|
| 74 |
-
|
| 75 |
-
# Critical: these stay as floats to preserve fractional precision
|
| 76 |
-
self._seconds_per_bar = float(params.beats_per_bar * 60.0 / params.bpm)
|
| 77 |
-
self._frames_per_bar = self._seconds_per_bar * self._codec_fps
|
| 78 |
-
self._samples_per_bar_model = self._seconds_per_bar * self._model_sr
|
| 79 |
-
self._samples_per_bar_target = self._seconds_per_bar * self._target_sr
|
| 80 |
-
|
| 81 |
-
# Timing state
|
| 82 |
-
self._timing = TimingState(
|
| 83 |
-
frames_per_bar=self._frames_per_bar,
|
| 84 |
-
samples_per_bar=self._samples_per_bar_model
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# Warn about problematic BPMs
|
| 88 |
-
frame_error = abs(self._frames_per_bar - round(self._frames_per_bar))
|
| 89 |
-
if frame_error > 0.01:
|
| 90 |
-
print(f"⚠️ Warning: {params.bpm} BPM creates {frame_error:.3f} frame drift per bar")
|
| 91 |
-
print(f" This may cause gradual timing drift in long jams")
|
| 92 |
-
|
| 93 |
-
# Synchronization + placeholders
|
| 94 |
self._lock = threading.Lock()
|
| 95 |
-
self._original_context_tokens = None
|
| 96 |
|
| 97 |
if params.combined_loop is not None:
|
| 98 |
self._setup_context_from_combined_loop()
|
|
@@ -101,39 +50,28 @@ class JamWorker(threading.Thread):
|
|
| 101 |
self.outbox: list[JamChunk] = []
|
| 102 |
self._stop_event = threading.Event()
|
| 103 |
|
| 104 |
-
# Stream state
|
| 105 |
self._stream = None
|
| 106 |
-
self.
|
| 107 |
-
|
| 108 |
-
#
|
| 109 |
self._last_delivered_index = 0
|
| 110 |
self._max_buffer_ahead = 5
|
| 111 |
|
| 112 |
-
# Streaming resampler for precise SR conversion
|
| 113 |
-
self._resampler = None
|
| 114 |
-
if self._target_sr != self._model_sr:
|
| 115 |
-
self._resampler = StreamingResampler(
|
| 116 |
-
in_sr=self._model_sr,
|
| 117 |
-
out_sr=self._target_sr,
|
| 118 |
-
channels=2,
|
| 119 |
-
quality="VHQ"
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
# Timing info
|
| 123 |
self.last_chunk_started_at = None
|
| 124 |
self.last_chunk_completed_at = None
|
| 125 |
|
| 126 |
-
#
|
| 127 |
-
self.
|
| 128 |
-
self.
|
| 129 |
-
|
| 130 |
|
| 131 |
def _setup_context_from_combined_loop(self):
|
| 132 |
"""Set up MRT context tokens from the combined loop audio"""
|
| 133 |
try:
|
| 134 |
from utils import make_bar_aligned_context, take_bar_aligned_tail
|
| 135 |
|
| 136 |
-
codec_fps = self.
|
| 137 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
| 138 |
|
| 139 |
loop_for_context = take_bar_aligned_tail(
|
|
@@ -146,381 +84,452 @@ class JamWorker(threading.Thread):
|
|
| 146 |
tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
|
| 147 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 148 |
|
| 149 |
-
# Use enhanced context alignment for fractional BPMs
|
| 150 |
context_tokens = make_bar_aligned_context(
|
| 151 |
tokens,
|
| 152 |
bpm=self.params.bpm,
|
| 153 |
-
fps=self.
|
| 154 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 155 |
-
beats_per_bar=self.params.beats_per_bar
|
| 156 |
-
precise_timing=True # Use new precise mode
|
| 157 |
)
|
| 158 |
|
|
|
|
| 159 |
self.state.context_tokens = context_tokens
|
| 160 |
-
print(f"
|
| 161 |
|
| 162 |
-
#
|
|
|
|
| 163 |
with self._lock:
|
| 164 |
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
| 165 |
-
self._original_context_tokens = np.copy(context_tokens)
|
| 166 |
|
| 167 |
except Exception as e:
|
| 168 |
-
print(f"Failed to setup context from combined loop: {e}")
|
| 169 |
|
| 170 |
def stop(self):
|
| 171 |
self._stop_event.set()
|
| 172 |
|
| 173 |
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
|
| 174 |
with self._lock:
|
| 175 |
-
if guidance_weight is not None:
|
| 176 |
-
|
| 177 |
-
if
|
| 178 |
-
self.params.temperature = float(temperature)
|
| 179 |
-
if topk is not None:
|
| 180 |
-
self.params.topk = int(topk)
|
| 181 |
|
| 182 |
def get_next_chunk(self) -> JamChunk | None:
|
| 183 |
"""Get the next sequential chunk (blocks/waits if not ready)"""
|
| 184 |
target_index = self._last_delivered_index + 1
|
| 185 |
|
| 186 |
-
|
|
|
|
| 187 |
start_time = time.time()
|
| 188 |
|
| 189 |
while time.time() - start_time < max_wait and not self._stop_event.is_set():
|
| 190 |
with self._lock:
|
|
|
|
| 191 |
for chunk in self.outbox:
|
| 192 |
if chunk.index == target_index:
|
| 193 |
self._last_delivered_index = target_index
|
| 194 |
-
print(f"Delivered chunk {target_index}
|
| 195 |
return chunk
|
|
|
|
|
|
|
| 196 |
time.sleep(0.1)
|
| 197 |
|
|
|
|
| 198 |
return None
|
| 199 |
|
| 200 |
def mark_chunk_consumed(self, chunk_index: int):
|
| 201 |
"""Mark a chunk as consumed by the frontend"""
|
| 202 |
with self._lock:
|
| 203 |
self._last_delivered_index = max(self._last_delivered_index, chunk_index)
|
|
|
|
| 204 |
|
| 205 |
def _should_generate_next_chunk(self) -> bool:
|
| 206 |
-
"""Check if we should generate the next chunk"""
|
| 207 |
with self._lock:
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
def _append_model_chunk_to_stream(self, wav):
|
| 216 |
-
"""
|
| 217 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 218 |
-
sr = self.
|
| 219 |
xfade_n = int(round(xfade_s * sr))
|
| 220 |
|
| 221 |
s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
| 222 |
|
| 223 |
-
if self
|
| 224 |
-
# First chunk: drop model pre-roll
|
| 225 |
if s.shape[0] > xfade_n:
|
| 226 |
self._stream = s[xfade_n:].astype(np.float32, copy=True)
|
| 227 |
else:
|
| 228 |
self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
|
| 229 |
-
self.
|
| 230 |
return
|
| 231 |
|
| 232 |
-
# Crossfade with
|
| 233 |
if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
|
| 234 |
-
# Degenerate
|
| 235 |
self._stream = np.concatenate([self._stream, s], axis=0)
|
| 236 |
-
self._stream_write_pos = self._stream.shape[0]
|
| 237 |
return
|
| 238 |
|
| 239 |
-
# Standard crossfade
|
| 240 |
tail = self._stream[-xfade_n:]
|
| 241 |
head = s[:xfade_n]
|
| 242 |
|
|
|
|
| 243 |
t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
|
| 244 |
eq_in, eq_out = np.sin(t), np.cos(t)
|
| 245 |
mixed = tail * eq_out + head * eq_in
|
| 246 |
|
| 247 |
self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
|
| 248 |
-
self._stream_write_pos = self._stream.shape[0]
|
| 249 |
-
|
| 250 |
-
def _extract_precise_chunk(self, start_bars: float, chunk_bars: float) -> np.ndarray:
|
| 251 |
-
"""Extract exactly chunk_bars worth of audio starting at start_bars"""
|
| 252 |
-
start_samples = self._get_precise_chunk_samples(start_bars)
|
| 253 |
-
chunk_samples = self._get_precise_chunk_samples(chunk_bars)
|
| 254 |
-
end_samples = start_samples + chunk_samples
|
| 255 |
-
|
| 256 |
-
if end_samples > self._stream.shape[0]:
|
| 257 |
-
return None # Not enough audio generated yet
|
| 258 |
-
|
| 259 |
-
return self._stream[start_samples:end_samples]
|
| 260 |
-
|
| 261 |
-
def _perform_onset_alignment(self, ref_loop: au.Waveform) -> float:
|
| 262 |
-
"""Estimate timing offset between generated audio and reference"""
|
| 263 |
-
if self._stream is None or self._stream.shape[0] < self._model_sr:
|
| 264 |
-
return 0.0
|
| 265 |
-
|
| 266 |
-
try:
|
| 267 |
-
# Take first ~2 seconds of generated audio
|
| 268 |
-
gen_samples = min(int(2.0 * self._model_sr), self._stream.shape[0])
|
| 269 |
-
gen_head = au.Waveform(
|
| 270 |
-
self._stream[:gen_samples].astype(np.float32, copy=False),
|
| 271 |
-
self._model_sr
|
| 272 |
-
).as_stereo()
|
| 273 |
-
|
| 274 |
-
# Reference: last bar of the loop
|
| 275 |
-
ref_samples = int(self._seconds_per_bar * ref_loop.sample_rate)
|
| 276 |
-
if ref_loop.samples.shape[0] >= ref_samples:
|
| 277 |
-
ref_tail = au.Waveform(
|
| 278 |
-
ref_loop.samples[-ref_samples:],
|
| 279 |
-
ref_loop.sample_rate
|
| 280 |
-
).resample(self._model_sr).as_stereo()
|
| 281 |
-
else:
|
| 282 |
-
ref_tail = ref_loop.resample(self._model_sr).as_stereo()
|
| 283 |
-
|
| 284 |
-
# Cross-correlation based alignment
|
| 285 |
-
def envelope(x, sr):
|
| 286 |
-
if x.ndim == 2:
|
| 287 |
-
x = x.mean(axis=1)
|
| 288 |
-
x = np.abs(x).astype(np.float32)
|
| 289 |
-
# Simple smoothing
|
| 290 |
-
win = max(1, int(0.01 * sr)) # 10ms window
|
| 291 |
-
if win > 1:
|
| 292 |
-
kernel = np.ones(win) / win
|
| 293 |
-
x = np.convolve(x, kernel, mode='same')
|
| 294 |
-
return x
|
| 295 |
-
|
| 296 |
-
env_ref = envelope(ref_tail.samples, self._model_sr)
|
| 297 |
-
env_gen = envelope(gen_head.samples, self._model_sr)
|
| 298 |
-
|
| 299 |
-
# Limit search range to reasonable offset
|
| 300 |
-
max_offset_samples = int(0.2 * self._model_sr) # 200ms max
|
| 301 |
-
|
| 302 |
-
# Normalize for correlation
|
| 303 |
-
env_ref = (env_ref - env_ref.mean()) / (env_ref.std() + 1e-8)
|
| 304 |
-
env_gen = (env_gen - env_gen.mean()) / (env_gen.std() + 1e-8)
|
| 305 |
-
|
| 306 |
-
# Find best correlation
|
| 307 |
-
best_offset = 0
|
| 308 |
-
best_corr = -1.0
|
| 309 |
-
|
| 310 |
-
search_len = min(len(env_ref), len(env_gen) - max_offset_samples)
|
| 311 |
-
if search_len > 0:
|
| 312 |
-
for offset in range(0, max_offset_samples, 4): # subsample for speed
|
| 313 |
-
if offset + search_len >= len(env_gen):
|
| 314 |
-
break
|
| 315 |
-
corr = np.corrcoef(env_ref[:search_len], env_gen[offset:offset+search_len])[0,1]
|
| 316 |
-
if not np.isnan(corr) and corr > best_corr:
|
| 317 |
-
best_corr = corr
|
| 318 |
-
best_offset = offset
|
| 319 |
-
|
| 320 |
-
offset_seconds = best_offset / self._model_sr
|
| 321 |
-
print(f"Onset alignment: {offset_seconds:.3f}s offset (correlation: {best_corr:.3f})")
|
| 322 |
-
return offset_seconds
|
| 323 |
-
|
| 324 |
-
except Exception as e:
|
| 325 |
-
print(f"Onset alignment failed: {e}")
|
| 326 |
-
return 0.0
|
| 327 |
-
|
| 328 |
-
def _align_to_bar_boundary(self):
|
| 329 |
-
"""Align timing state to next bar boundary"""
|
| 330 |
-
current_bar = self._timing.emit_position_bars
|
| 331 |
-
next_bar = math.ceil(current_bar)
|
| 332 |
-
|
| 333 |
-
if abs(next_bar - current_bar) > 1e-6:
|
| 334 |
-
skip_bars = next_bar - current_bar
|
| 335 |
-
skip_samples = self._get_precise_chunk_samples(skip_bars)
|
| 336 |
-
self._timing.stream_position_samples += skip_samples
|
| 337 |
-
self._timing.emit_position_bars = next_bar
|
| 338 |
-
print(f"Aligned to bar {next_bar:.0f}, skipped {skip_bars:.4f} bars")
|
| 339 |
|
| 340 |
def reseed_from_waveform(self, wav):
|
| 341 |
-
|
| 342 |
new_state = self.mrt.init_state()
|
| 343 |
-
|
| 344 |
-
# Build
|
| 345 |
-
codec_fps
|
| 346 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
| 347 |
-
|
|
|
|
| 348 |
tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
|
| 349 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 350 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
tokens,
|
| 354 |
-
bpm=self.params.bpm,
|
| 355 |
-
fps=self._codec_fps,
|
| 356 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 357 |
-
beats_per_bar=self.params.beats_per_bar
|
| 358 |
-
precise_timing=True
|
| 359 |
)
|
| 360 |
-
|
| 361 |
new_state.context_tokens = context_tokens
|
| 362 |
self.state = new_state
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
self._needs_bar_realign = True
|
| 372 |
-
self._reseed_ref_loop = wav
|
| 373 |
|
| 374 |
def reseed_splice(self, recent_wav, anchor_bars: float):
|
| 375 |
-
"""
|
|
|
|
|
|
|
| 376 |
with self._lock:
|
| 377 |
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
| 378 |
self._original_context_tokens = np.copy(self.state.context_tokens)
|
| 379 |
|
| 380 |
-
|
| 381 |
-
recent_tokens = self._make_recent_tokens_from_wave(recent_wav)
|
| 382 |
new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
|
| 383 |
|
|
|
|
| 384 |
self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
|
| 385 |
-
|
| 386 |
-
#
|
| 387 |
new_state = self.mrt.init_state()
|
| 388 |
new_state.context_tokens = new_ctx
|
| 389 |
self.state = new_state
|
| 390 |
|
| 391 |
-
|
| 392 |
-
self._stream = None
|
| 393 |
-
self._stream_write_pos = 0
|
| 394 |
-
self._timing = TimingState(
|
| 395 |
-
frames_per_bar=self._frames_per_bar,
|
| 396 |
-
samples_per_bar=self._samples_per_bar_model
|
| 397 |
-
)
|
| 398 |
-
self._needs_bar_realign = True
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
tokens_full = self.mrt.codec.encode(wav).astype(np.int32)
|
| 403 |
-
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 404 |
-
|
| 405 |
-
context_tokens = make_bar_aligned_context(
|
| 406 |
-
tokens,
|
| 407 |
-
bpm=self.params.bpm,
|
| 408 |
-
fps=self._codec_fps,
|
| 409 |
-
ctx_frames=self.mrt.config.context_length_frames,
|
| 410 |
-
beats_per_bar=self.params.beats_per_bar,
|
| 411 |
-
precise_timing=True
|
| 412 |
-
)
|
| 413 |
-
return context_tokens
|
| 414 |
-
|
| 415 |
-
def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray, anchor_bars: float) -> np.ndarray:
|
| 416 |
-
"""Enhanced context splicing with fractional bar handling"""
|
| 417 |
-
ctx_frames = int(self.mrt.config.context_length_frames)
|
| 418 |
-
|
| 419 |
-
# Convert anchor bars to codec frames (keep fractional precision)
|
| 420 |
-
anchor_frames_f = anchor_bars * self._frames_per_bar
|
| 421 |
-
anchor_frames = int(round(anchor_frames_f))
|
| 422 |
-
|
| 423 |
-
# Take anchor from original
|
| 424 |
-
anchor = original_tokens[-anchor_frames:] if anchor_frames <= original_tokens.shape[0] else original_tokens
|
| 425 |
-
|
| 426 |
-
# Fill remainder with recent tokens
|
| 427 |
-
remain_frames = ctx_frames - anchor.shape[0]
|
| 428 |
-
if remain_frames > 0:
|
| 429 |
-
recent = recent_tokens[-remain_frames:] if remain_frames <= recent_tokens.shape[0] else recent_tokens
|
| 430 |
-
else:
|
| 431 |
-
recent = recent_tokens[:0] # empty
|
| 432 |
-
|
| 433 |
-
# Combine
|
| 434 |
-
if anchor.size > 0 and recent.size > 0:
|
| 435 |
-
spliced = np.concatenate([recent, anchor], axis=0)
|
| 436 |
-
elif anchor.size > 0:
|
| 437 |
-
spliced = anchor
|
| 438 |
-
else:
|
| 439 |
-
spliced = recent_tokens[-ctx_frames:]
|
| 440 |
-
|
| 441 |
-
# Ensure exact length
|
| 442 |
-
if spliced.shape[0] > ctx_frames:
|
| 443 |
-
spliced = spliced[-ctx_frames:]
|
| 444 |
-
elif spliced.shape[0] < ctx_frames:
|
| 445 |
-
# Tile to fill
|
| 446 |
-
reps = int(np.ceil(ctx_frames / max(1, spliced.shape[0])))
|
| 447 |
-
spliced = np.tile(spliced, (reps, 1))[-ctx_frames:]
|
| 448 |
-
|
| 449 |
-
return spliced
|
| 450 |
|
| 451 |
def run(self):
|
| 452 |
-
"""Main
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
if first_chunk_extra:
|
| 464 |
-
#
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
while not self._stop_event.is_set():
|
| 472 |
if not self._should_generate_next_chunk():
|
| 473 |
time.sleep(0.25)
|
| 474 |
continue
|
| 475 |
|
| 476 |
-
# 1) Generate until we have enough
|
| 477 |
-
|
| 478 |
-
while
|
| 479 |
with self._lock:
|
| 480 |
style_vec = self.params.style_vec
|
| 481 |
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
| 482 |
-
self.mrt.temperature
|
| 483 |
-
self.mrt.topk
|
| 484 |
-
|
| 485 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 486 |
-
self._append_model_chunk_to_stream(wav)
|
| 487 |
-
|
| 488 |
|
| 489 |
if self._stop_event.is_set():
|
| 490 |
break
|
| 491 |
|
| 492 |
-
# 2)
|
| 493 |
if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
|
| 494 |
ref_loop = self._reseed_ref_loop or self.params.combined_loop
|
| 495 |
if ref_loop is not None:
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
|
|
|
| 503 |
self._needs_bar_realign = False
|
| 504 |
self._reseed_ref_loop = None
|
| 505 |
|
| 506 |
-
# 3)
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
continue
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
if correction != 0:
|
| 516 |
-
print(f"Applied {correction} bar timing correction")
|
| 517 |
-
|
| 518 |
-
self._timing.stream_position_samples += chunk_samples
|
| 519 |
|
| 520 |
-
|
| 521 |
-
y = au.Waveform(slice_audio.astype(np.float32, copy=False), self._model_sr).as_stereo()
|
| 522 |
|
| 523 |
-
#
|
| 524 |
if self.idx == 0 and self.params.ref_loop is not None:
|
| 525 |
y, _ = match_loudness_to_reference(
|
| 526 |
self.params.ref_loop, y,
|
|
@@ -530,96 +539,38 @@ class JamWorker(threading.Thread):
|
|
| 530 |
else:
|
| 531 |
apply_micro_fades(y, 3)
|
| 532 |
|
| 533 |
-
# 5)
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
# Ensure exact target length
|
| 539 |
-
target_samples = int(round(chunk_bars * self._samples_per_bar_target))
|
| 540 |
-
if resampled.shape[0] != target_samples:
|
| 541 |
-
if resampled.shape[0] < target_samples:
|
| 542 |
-
pad_samples = target_samples - resampled.shape[0]
|
| 543 |
-
pad = np.zeros((pad_samples, resampled.shape[1]), dtype=resampled.dtype)
|
| 544 |
-
resampled = np.vstack([resampled, pad])
|
| 545 |
-
else:
|
| 546 |
-
resampled = resampled[:target_samples]
|
| 547 |
-
|
| 548 |
-
final_audio = resampled
|
| 549 |
-
final_sr = self._target_sr
|
| 550 |
-
else:
|
| 551 |
-
# No resampling needed
|
| 552 |
-
final_audio = y.samples
|
| 553 |
-
final_sr = self._model_sr
|
| 554 |
|
| 555 |
-
# 6)
|
| 556 |
-
b64, total_samples, channels = wav_bytes_base64(final_audio, final_sr)
|
| 557 |
-
|
| 558 |
-
# 7) Create metadata with timing info
|
| 559 |
-
actual_duration = total_samples / final_sr
|
| 560 |
-
bar_range = f"{chunk_start_bars:.2f}-{self._timing.emit_position_bars:.2f}"
|
| 561 |
-
|
| 562 |
-
meta = {
|
| 563 |
-
"bpm": int(round(self.params.bpm)),
|
| 564 |
-
"bars": int(self.params.bars_per_chunk),
|
| 565 |
-
"beats_per_bar": int(self.params.beats_per_bar),
|
| 566 |
-
"sample_rate": int(final_sr),
|
| 567 |
-
"channels": int(channels),
|
| 568 |
-
"total_samples": int(total_samples),
|
| 569 |
-
"seconds_per_bar": self._seconds_per_bar,
|
| 570 |
-
"loop_duration_seconds": actual_duration,
|
| 571 |
-
"bar_range": bar_range,
|
| 572 |
-
"timing_state": {
|
| 573 |
-
"emit_position_bars": self._timing.emit_position_bars,
|
| 574 |
-
"frames_per_bar": self._frames_per_bar,
|
| 575 |
-
"fractional_error": self._timing.fractional_error_bars,
|
| 576 |
-
},
|
| 577 |
-
"xfade_seconds": xfade_s,
|
| 578 |
-
"guidance_weight": self.params.guidance_weight,
|
| 579 |
-
"temperature": self.params.temperature,
|
| 580 |
-
"topk": self.params.topk,
|
| 581 |
-
}
|
| 582 |
-
|
| 583 |
-
# 8) Publish chunk
|
| 584 |
with self._lock:
|
| 585 |
self.idx += 1
|
| 586 |
-
|
| 587 |
-
self.outbox.append(chunk)
|
| 588 |
-
|
| 589 |
-
# Cleanup old chunks
|
| 590 |
if len(self.outbox) > 10:
|
| 591 |
cutoff = self._last_delivered_index - 5
|
| 592 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
| 593 |
|
| 594 |
-
#
|
| 595 |
if self._pending_reseed is not None:
|
| 596 |
pkg = self._pending_reseed
|
| 597 |
self._pending_reseed = None
|
| 598 |
|
| 599 |
new_state = self.mrt.init_state()
|
| 600 |
-
new_state.context_tokens = pkg["ctx"]
|
| 601 |
self.state = new_state
|
| 602 |
|
| 603 |
-
#
|
| 604 |
self._stream = None
|
| 605 |
-
self.
|
| 606 |
-
self.
|
| 607 |
-
frames_per_bar=self._frames_per_bar,
|
| 608 |
-
samples_per_bar=self._samples_per_bar_model
|
| 609 |
-
)
|
| 610 |
-
self._reseed_ref_loop = pkg.get("ref")
|
| 611 |
self._needs_bar_realign = True
|
| 612 |
|
| 613 |
-
print("Reseed
|
| 614 |
|
| 615 |
-
|
| 616 |
-
|
|
|
|
| 617 |
|
| 618 |
-
print("JamWorker stopped")
|
| 619 |
-
|
| 620 |
-
# Clean up resampler
|
| 621 |
-
if self._resampler is not None:
|
| 622 |
-
try:
|
| 623 |
-
self._resampler.flush()
|
| 624 |
-
except:
|
| 625 |
-
pass
|
|
|
|
| 1 |
+
# jam_worker.py - SIMPLE FIX VERSION
|
| 2 |
+
import threading, time, base64, io, uuid
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
import numpy as np
|
| 5 |
import soundfile as sf
|
|
|
|
| 8 |
from utils import (
|
| 9 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
| 10 |
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
|
| 11 |
+
resample_and_snap, wav_bytes_base64
|
| 12 |
)
|
| 13 |
|
| 14 |
@dataclass
|
|
|
|
| 32 |
audio_base64: str
|
| 33 |
metadata: dict
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
class JamWorker(threading.Thread):
|
| 36 |
def __init__(self, mrt, params: JamParams):
|
| 37 |
super().__init__(daemon=True)
|
|
|
|
| 39 |
self.params = params
|
| 40 |
self.state = mrt.init_state()
|
| 41 |
|
| 42 |
+
# ✅ init synchronization + placeholders FIRST
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
self._lock = threading.Lock()
|
| 44 |
+
self._original_context_tokens = None # so hasattr checks are cheap/clear
|
| 45 |
|
| 46 |
if params.combined_loop is not None:
|
| 47 |
self._setup_context_from_combined_loop()
|
|
|
|
| 50 |
self.outbox: list[JamChunk] = []
|
| 51 |
self._stop_event = threading.Event()
|
| 52 |
|
|
|
|
| 53 |
self._stream = None
|
| 54 |
+
self._next_emit_start = 0
|
| 55 |
+
|
| 56 |
+
# NEW: Track delivery state
|
| 57 |
self._last_delivered_index = 0
|
| 58 |
self._max_buffer_ahead = 5
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# Timing info
|
| 61 |
self.last_chunk_started_at = None
|
| 62 |
self.last_chunk_completed_at = None
|
| 63 |
|
| 64 |
+
self._pending_reseed = None # {"ctx": np.ndarray, "ref": au.Waveform|None}
|
| 65 |
+
self._needs_bar_realign = False # request a one-shot downbeat alignment
|
| 66 |
+
self._reseed_ref_loop = None # which loop to align against after reseed
|
| 67 |
+
|
| 68 |
|
| 69 |
def _setup_context_from_combined_loop(self):
|
| 70 |
"""Set up MRT context tokens from the combined loop audio"""
|
| 71 |
try:
|
| 72 |
from utils import make_bar_aligned_context, take_bar_aligned_tail
|
| 73 |
|
| 74 |
+
codec_fps = float(self.mrt.codec.frame_rate)
|
| 75 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
| 76 |
|
| 77 |
loop_for_context = take_bar_aligned_tail(
|
|
|
|
| 84 |
tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
|
| 85 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 86 |
|
|
|
|
| 87 |
context_tokens = make_bar_aligned_context(
|
| 88 |
tokens,
|
| 89 |
bpm=self.params.bpm,
|
| 90 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
| 91 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 92 |
+
beats_per_bar=self.params.beats_per_bar
|
|
|
|
| 93 |
)
|
| 94 |
|
| 95 |
+
# Install fresh context
|
| 96 |
self.state.context_tokens = context_tokens
|
| 97 |
+
print(f"✅ JamWorker: Set up fresh context from combined loop")
|
| 98 |
|
| 99 |
+
# NEW: keep a copy of the *original* context tokens for future splice-reseed
|
| 100 |
+
# (guard so we only set this once, at jam start)
|
| 101 |
with self._lock:
|
| 102 |
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
| 103 |
+
self._original_context_tokens = np.copy(context_tokens) # shape: [T, depth]
|
| 104 |
|
| 105 |
except Exception as e:
|
| 106 |
+
print(f"❌ Failed to setup context from combined loop: {e}")
|
| 107 |
|
| 108 |
def stop(self):
|
| 109 |
self._stop_event.set()
|
| 110 |
|
| 111 |
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
|
| 112 |
with self._lock:
|
| 113 |
+
if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
|
| 114 |
+
if temperature is not None: self.params.temperature = float(temperature)
|
| 115 |
+
if topk is not None: self.params.topk = int(topk)
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
def get_next_chunk(self) -> JamChunk | None:
|
| 118 |
"""Get the next sequential chunk (blocks/waits if not ready)"""
|
| 119 |
target_index = self._last_delivered_index + 1
|
| 120 |
|
| 121 |
+
# Wait for the target chunk to be ready (with timeout)
|
| 122 |
+
max_wait = 30.0 # seconds
|
| 123 |
start_time = time.time()
|
| 124 |
|
| 125 |
while time.time() - start_time < max_wait and not self._stop_event.is_set():
|
| 126 |
with self._lock:
|
| 127 |
+
# Look for the exact chunk we need
|
| 128 |
for chunk in self.outbox:
|
| 129 |
if chunk.index == target_index:
|
| 130 |
self._last_delivered_index = target_index
|
| 131 |
+
print(f"📦 Delivered chunk {target_index}")
|
| 132 |
return chunk
|
| 133 |
+
|
| 134 |
+
# Not ready yet, wait a bit
|
| 135 |
time.sleep(0.1)
|
| 136 |
|
| 137 |
+
# Timeout or stopped
|
| 138 |
return None
|
| 139 |
|
| 140 |
def mark_chunk_consumed(self, chunk_index: int):
|
| 141 |
"""Mark a chunk as consumed by the frontend"""
|
| 142 |
with self._lock:
|
| 143 |
self._last_delivered_index = max(self._last_delivered_index, chunk_index)
|
| 144 |
+
print(f"✅ Chunk {chunk_index} consumed")
|
| 145 |
|
| 146 |
def _should_generate_next_chunk(self) -> bool:
|
| 147 |
+
"""Check if we should generate the next chunk (don't get too far ahead)"""
|
| 148 |
with self._lock:
|
| 149 |
+
# Don't generate if we're already too far ahead
|
| 150 |
+
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
| 151 |
+
return False
|
| 152 |
+
return True
|
| 153 |
+
|
| 154 |
+
def _seconds_per_bar(self) -> float:
|
| 155 |
+
return self.params.beats_per_bar * (60.0 / self.params.bpm)
|
| 156 |
+
|
| 157 |
+
def _snap_and_encode(self, y, seconds, target_sr, bars):
|
| 158 |
+
cur_sr = int(self.mrt.sample_rate)
|
| 159 |
+
x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
|
| 160 |
+
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
|
| 161 |
+
b64, total_samples, channels = wav_bytes_base64(x, target_sr)
|
| 162 |
+
meta = {
|
| 163 |
+
"bpm": int(round(self.params.bpm)),
|
| 164 |
+
"bars": int(bars),
|
| 165 |
+
"beats_per_bar": int(self.params.beats_per_bar),
|
| 166 |
+
"sample_rate": int(target_sr),
|
| 167 |
+
"channels": channels,
|
| 168 |
+
"total_samples": total_samples,
|
| 169 |
+
"seconds_per_bar": self._seconds_per_bar(),
|
| 170 |
+
"loop_duration_seconds": bars * self._seconds_per_bar(),
|
| 171 |
+
"guidance_weight": self.params.guidance_weight,
|
| 172 |
+
"temperature": self.params.temperature,
|
| 173 |
+
"topk": self.params.topk,
|
| 174 |
+
}
|
| 175 |
+
return b64, meta
|
| 176 |
|
| 177 |
def _append_model_chunk_to_stream(self, wav):
|
| 178 |
+
"""Incrementally append a model chunk with equal-power crossfade."""
|
| 179 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 180 |
+
sr = int(self.mrt.sample_rate)
|
| 181 |
xfade_n = int(round(xfade_s * sr))
|
| 182 |
|
| 183 |
s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
| 184 |
|
| 185 |
+
if getattr(self, "_stream", None) is None:
|
| 186 |
+
# First chunk: drop model pre-roll (xfade head)
|
| 187 |
if s.shape[0] > xfade_n:
|
| 188 |
self._stream = s[xfade_n:].astype(np.float32, copy=True)
|
| 189 |
else:
|
| 190 |
self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
|
| 191 |
+
self._next_emit_start = 0 # pointer into _stream (model SR samples)
|
| 192 |
return
|
| 193 |
|
| 194 |
+
# Crossfade last xfade_n samples of _stream with head of new s
|
| 195 |
if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
|
| 196 |
+
# Degenerate safeguard
|
| 197 |
self._stream = np.concatenate([self._stream, s], axis=0)
|
|
|
|
| 198 |
return
|
| 199 |
|
|
|
|
| 200 |
tail = self._stream[-xfade_n:]
|
| 201 |
head = s[:xfade_n]
|
| 202 |
|
| 203 |
+
# Equal-power envelopes
|
| 204 |
t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
|
| 205 |
eq_in, eq_out = np.sin(t), np.cos(t)
|
| 206 |
mixed = tail * eq_out + head * eq_in
|
| 207 |
|
| 208 |
self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
def reseed_from_waveform(self, wav):
|
| 211 |
+
# 1) Re-init state
|
| 212 |
new_state = self.mrt.init_state()
|
| 213 |
+
|
| 214 |
+
# 2) Build bar-aligned context tokens from provided audio
|
| 215 |
+
codec_fps = float(self.mrt.codec.frame_rate)
|
| 216 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
| 217 |
+
from utils import take_bar_aligned_tail, make_bar_aligned_context
|
| 218 |
+
|
| 219 |
tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
|
| 220 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 221 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 222 |
+
context_tokens = make_bar_aligned_context(tokens,
|
| 223 |
+
bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
|
|
|
|
|
|
|
|
|
|
| 224 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 225 |
+
beats_per_bar=self.params.beats_per_bar
|
|
|
|
| 226 |
)
|
|
|
|
| 227 |
new_state.context_tokens = context_tokens
|
| 228 |
self.state = new_state
|
| 229 |
+
self._prepare_stream_for_reseed_handoff()
|
| 230 |
+
|
| 231 |
+
def _frames_per_bar(self) -> int:
|
| 232 |
+
# codec frame-rate (frames/s) -> frames per musical bar
|
| 233 |
+
fps = float(self.mrt.codec.frame_rate)
|
| 234 |
+
sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar)
|
| 235 |
+
return int(round(fps * sec_per_bar))
|
| 236 |
+
|
| 237 |
+
def _ctx_frames(self) -> int:
|
| 238 |
+
# how many codec frames fit in the model’s conditioning window
|
| 239 |
+
return int(self.mrt.config.context_length_frames)
|
| 240 |
+
|
| 241 |
+
def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
|
| 242 |
+
"""
|
| 243 |
+
Encode waveform and produce a BAR-ALIGNED context token window.
|
| 244 |
+
"""
|
| 245 |
+
tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
|
| 246 |
+
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 247 |
+
|
| 248 |
+
from utils import make_bar_aligned_context
|
| 249 |
+
ctx = make_bar_aligned_context(
|
| 250 |
+
tokens,
|
| 251 |
+
bpm=self.params.bpm,
|
| 252 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
| 253 |
+
ctx_frames=self.mrt.config.context_length_frames,
|
| 254 |
+
beats_per_bar=self.params.beats_per_bar
|
| 255 |
)
|
| 256 |
+
return ctx
|
| 257 |
+
|
| 258 |
+
def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
|
| 259 |
+
"""
|
| 260 |
+
Take a tail slice that is an integer number of codec frames corresponding to `bars`.
|
| 261 |
+
We round to nearest frame to stay phase-consistent with codec grid.
|
| 262 |
+
"""
|
| 263 |
+
frames_per_bar = self._frames_per_bar()
|
| 264 |
+
want = max(frames_per_bar * int(round(bars)), 0)
|
| 265 |
+
if want == 0:
|
| 266 |
+
return tokens[:0] # empty
|
| 267 |
+
if tokens.shape[0] <= want:
|
| 268 |
+
return tokens
|
| 269 |
+
return tokens[-want:]
|
| 270 |
+
|
| 271 |
+
def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
|
| 272 |
+
anchor_bars: float) -> np.ndarray:
|
| 273 |
+
import math
|
| 274 |
+
ctx_frames = self._ctx_frames()
|
| 275 |
+
depth = original_tokens.shape[1]
|
| 276 |
+
frames_per_bar = self._frames_per_bar()
|
| 277 |
+
|
| 278 |
+
# 1) Anchor tail (whole bars)
|
| 279 |
+
anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
|
| 280 |
+
|
| 281 |
+
# 2) Fill remainder with recent (prefer whole bars)
|
| 282 |
+
a = anchor.shape[0]
|
| 283 |
+
remain = max(ctx_frames - a, 0)
|
| 284 |
+
|
| 285 |
+
recent = recent_tokens[:0]
|
| 286 |
+
used_recent = 0 # frames taken from the END of recent_tokens
|
| 287 |
+
if remain > 0:
|
| 288 |
+
bars_fit = remain // frames_per_bar
|
| 289 |
+
if bars_fit >= 1:
|
| 290 |
+
want_recent_frames = int(bars_fit * frames_per_bar)
|
| 291 |
+
used_recent = min(want_recent_frames, recent_tokens.shape[0])
|
| 292 |
+
recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
|
| 293 |
+
else:
|
| 294 |
+
used_recent = min(remain, recent_tokens.shape[0])
|
| 295 |
+
recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
|
| 296 |
+
|
| 297 |
+
# 3) Concat in order [anchor, recent]
|
| 298 |
+
if anchor.size or recent.size:
|
| 299 |
+
out = np.concatenate([anchor, recent], axis=0)
|
| 300 |
+
else:
|
| 301 |
+
# fallback: just take the last ctx window from recent
|
| 302 |
+
out = recent_tokens[-ctx_frames:]
|
| 303 |
+
|
| 304 |
+
# 4) Trim if we overshot
|
| 305 |
+
if out.shape[0] > ctx_frames:
|
| 306 |
+
out = out[-ctx_frames:]
|
| 307 |
+
|
| 308 |
+
# 5) Snap the **END** to the nearest LOWER bar boundary
|
| 309 |
+
if frames_per_bar > 0:
|
| 310 |
+
max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
|
| 311 |
+
else:
|
| 312 |
+
max_bar_aligned = out.shape[0]
|
| 313 |
+
if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
|
| 314 |
+
out = out[-max_bar_aligned:]
|
| 315 |
+
|
| 316 |
+
# 6) Left-fill to reach ctx_frames **without moving the END**
|
| 317 |
+
deficit = ctx_frames - out.shape[0]
|
| 318 |
+
if deficit > 0:
|
| 319 |
+
left_parts = []
|
| 320 |
+
|
| 321 |
+
# Prefer frames immediately BEFORE the region we used from 'recent_tokens'
|
| 322 |
+
if used_recent < recent_tokens.shape[0]:
|
| 323 |
+
take = min(deficit, recent_tokens.shape[0] - used_recent)
|
| 324 |
+
if used_recent > 0:
|
| 325 |
+
left_parts.append(recent_tokens[-(used_recent + take) : -used_recent])
|
| 326 |
+
else:
|
| 327 |
+
left_parts.append(recent_tokens[-take:])
|
| 328 |
+
|
| 329 |
+
# Then take frames immediately BEFORE the 'anchor' in original_tokens
|
| 330 |
+
if sum(p.shape[0] for p in left_parts) < deficit and anchor.shape[0] > 0:
|
| 331 |
+
need = deficit - sum(p.shape[0] for p in left_parts)
|
| 332 |
+
a_len = anchor.shape[0]
|
| 333 |
+
avail = max(original_tokens.shape[0] - a_len, 0)
|
| 334 |
+
take2 = min(need, avail)
|
| 335 |
+
if take2 > 0:
|
| 336 |
+
left_parts.append(original_tokens[-(a_len + take2) : -a_len])
|
| 337 |
+
|
| 338 |
+
# Still short? tile from what's available
|
| 339 |
+
have = sum(p.shape[0] for p in left_parts)
|
| 340 |
+
if have < deficit:
|
| 341 |
+
base = out if out.shape[0] > 0 else (recent_tokens if recent_tokens.shape[0] > 0 else original_tokens)
|
| 342 |
+
reps = int(np.ceil((deficit - have) / max(1, base.shape[0])))
|
| 343 |
+
left_parts.append(np.tile(base, (reps, 1))[: (deficit - have)])
|
| 344 |
+
|
| 345 |
+
left = np.concatenate(left_parts, axis=0)
|
| 346 |
+
out = np.concatenate([left[-deficit:], out], axis=0)
|
| 347 |
+
|
| 348 |
+
# 7) Final guard to exact length
|
| 349 |
+
if out.shape[0] > ctx_frames:
|
| 350 |
+
out = out[-ctx_frames:]
|
| 351 |
+
elif out.shape[0] < ctx_frames:
|
| 352 |
+
reps = int(np.ceil(ctx_frames / max(1, out.shape[0])))
|
| 353 |
+
out = np.tile(out, (reps, 1))[-ctx_frames:]
|
| 354 |
+
|
| 355 |
+
# 8) Depth guard
|
| 356 |
+
if out.shape[1] != depth:
|
| 357 |
+
out = out[:, :depth]
|
| 358 |
+
return out
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _realign_emit_pointer_to_bar(self, sr_model: int):
|
| 362 |
+
"""Advance _next_emit_start to the next bar boundary in model-sample space."""
|
| 363 |
+
bar_samps = int(round(self._seconds_per_bar() * sr_model))
|
| 364 |
+
if bar_samps <= 0:
|
| 365 |
+
return
|
| 366 |
+
phase = self._next_emit_start % bar_samps
|
| 367 |
+
if phase != 0:
|
| 368 |
+
self._next_emit_start += (bar_samps - phase)
|
| 369 |
+
|
| 370 |
+
def _prepare_stream_for_reseed_handoff(self):
|
| 371 |
+
# OLD: keep crossfade tail -> causes phase offset
|
| 372 |
+
# sr = int(self.mrt.sample_rate)
|
| 373 |
+
# xfade_s = float(self.mrt.config.crossfade_length)
|
| 374 |
+
# xfade_n = int(round(xfade_s * sr))
|
| 375 |
+
# if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
|
| 376 |
+
# tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream
|
| 377 |
+
# self._stream = tail.copy()
|
| 378 |
+
# else:
|
| 379 |
+
# self._stream = None
|
| 380 |
+
|
| 381 |
+
# NEW: throw away the tail completely; start fresh
|
| 382 |
+
self._stream = None
|
| 383 |
+
|
| 384 |
+
self._next_emit_start = 0
|
| 385 |
self._needs_bar_realign = True
|
|
|
|
| 386 |
|
| 387 |
def reseed_splice(self, recent_wav, anchor_bars: float):
|
| 388 |
+
"""
|
| 389 |
+
Token-splice reseed queued for the next bar boundary between chunks.
|
| 390 |
+
"""
|
| 391 |
with self._lock:
|
| 392 |
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
| 393 |
self._original_context_tokens = np.copy(self.state.context_tokens)
|
| 394 |
|
| 395 |
+
recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
|
|
|
|
| 396 |
new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
|
| 397 |
|
| 398 |
+
# Queue it; the run loop will install right after we finish the current slice
|
| 399 |
self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
|
| 400 |
+
|
| 401 |
+
# install the new context window
|
| 402 |
new_state = self.mrt.init_state()
|
| 403 |
new_state.context_tokens = new_ctx
|
| 404 |
self.state = new_state
|
| 405 |
|
| 406 |
+
self._prepare_stream_for_reseed_handoff()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
+
# optional: ask streamer to drop an intro crossfade worth of audio right after reseed
|
| 409 |
+
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
def run(self):
|
| 412 |
+
"""Main worker loop — generate into a continuous stream, then emit bar-aligned slices."""
|
| 413 |
+
spb = self._seconds_per_bar() # seconds per bar
|
| 414 |
+
chunk_secs = self.params.bars_per_chunk * spb
|
| 415 |
+
xfade = float(self.mrt.config.crossfade_length) # seconds
|
| 416 |
+
sr = int(self.mrt.sample_rate)
|
| 417 |
+
chunk_samps = int(round(chunk_secs * sr))
|
| 418 |
+
|
| 419 |
+
def _need(first_chunk_extra=False):
|
| 420 |
+
"""How many more samples we still need in the stream to emit next slice."""
|
| 421 |
+
have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
|
| 422 |
+
want = chunk_samps
|
| 423 |
if first_chunk_extra:
|
| 424 |
+
# reserve two bars extra so first-chunk onset alignment has material
|
| 425 |
+
want += int(round(2 * spb * sr))
|
| 426 |
+
return max(0, want - have)
|
| 427 |
+
|
| 428 |
+
def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
|
| 429 |
+
if x.ndim == 2: x = x.mean(axis=1)
|
| 430 |
+
x = np.abs(x).astype(np.float32)
|
| 431 |
+
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
| 432 |
+
if w > 1:
|
| 433 |
+
kern = np.ones(w, dtype=np.float32) / float(w)
|
| 434 |
+
x = np.convolve(x, kern, mode="same")
|
| 435 |
+
d = np.diff(x, prepend=x[:1])
|
| 436 |
+
d[d < 0] = 0.0
|
| 437 |
+
return d
|
| 438 |
+
|
| 439 |
+
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
|
| 440 |
+
"""Tempo-aware first-downbeat offset (positive => model late)."""
|
| 441 |
+
try:
|
| 442 |
+
max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
|
| 443 |
+
ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
|
| 444 |
+
n_bar = int(round(spb * sr))
|
| 445 |
+
ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
|
| 446 |
+
gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
|
| 447 |
+
if ref_tail.size == 0 or gen_head.size == 0:
|
| 448 |
+
return 0
|
| 449 |
+
|
| 450 |
+
# envelopes + z-score
|
| 451 |
+
import numpy as np
|
| 452 |
+
def _z(a):
|
| 453 |
+
m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
|
| 454 |
+
e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
|
| 455 |
+
e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
|
| 456 |
+
|
| 457 |
+
# upsample x4 for finer lag
|
| 458 |
+
def _upsample(a, r=4):
|
| 459 |
+
n = len(a); grid = np.arange(n, dtype=np.float32)
|
| 460 |
+
fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
|
| 461 |
+
return np.interp(fine, grid, a).astype(np.float32)
|
| 462 |
+
up = 4
|
| 463 |
+
e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
|
| 464 |
+
|
| 465 |
+
max_lag_u = int(round((max_ms / 1000.0) * sr * up))
|
| 466 |
+
seg = min(len(e_ref_u), len(e_gen_u))
|
| 467 |
+
e_ref_u = e_ref_u[-seg:]
|
| 468 |
+
pad = np.zeros(max_lag_u, dtype=np.float32)
|
| 469 |
+
e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
|
| 470 |
+
|
| 471 |
+
best_lag_u, best_score = 0, -1e9
|
| 472 |
+
for lag_u in range(-max_lag_u, max_lag_u + 1):
|
| 473 |
+
start = max_lag_u + lag_u
|
| 474 |
+
b = e_gen_u_pad[start : start + seg]
|
| 475 |
+
denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
|
| 476 |
+
score = float(np.dot(e_ref_u, b) / denom)
|
| 477 |
+
if score > best_score:
|
| 478 |
+
best_score, best_lag_u = score, lag_u
|
| 479 |
+
return int(round(best_lag_u / up))
|
| 480 |
+
except Exception:
|
| 481 |
+
return 0
|
| 482 |
+
|
| 483 |
+
print("🚀 JamWorker started (bar-aligned streaming)…")
|
| 484 |
|
| 485 |
while not self._stop_event.is_set():
|
| 486 |
if not self._should_generate_next_chunk():
|
| 487 |
time.sleep(0.25)
|
| 488 |
continue
|
| 489 |
|
| 490 |
+
# 1) Generate until we have enough material in the stream
|
| 491 |
+
need = _need(first_chunk_extra=(self.idx == 0))
|
| 492 |
+
while need > 0 and not self._stop_event.is_set():
|
| 493 |
with self._lock:
|
| 494 |
style_vec = self.params.style_vec
|
| 495 |
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
| 496 |
+
self.mrt.temperature = float(self.params.temperature)
|
| 497 |
+
self.mrt.topk = int(self.params.topk)
|
|
|
|
| 498 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 499 |
+
self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
|
| 500 |
+
need = _need(first_chunk_extra=(self.idx == 0))
|
| 501 |
|
| 502 |
if self._stop_event.is_set():
|
| 503 |
break
|
| 504 |
|
| 505 |
+
# 2) One-time: align the emit pointer to the groove
|
| 506 |
if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
|
| 507 |
ref_loop = self._reseed_ref_loop or self.params.combined_loop
|
| 508 |
if ref_loop is not None:
|
| 509 |
+
head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
|
| 510 |
+
seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
|
| 511 |
+
gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
|
| 512 |
+
offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
|
| 513 |
+
if offs != 0:
|
| 514 |
+
self._next_emit_start = max(0, self._next_emit_start + offs)
|
| 515 |
+
print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
|
| 516 |
+
self._realign_emit_pointer_to_bar(sr)
|
| 517 |
self._needs_bar_realign = False
|
| 518 |
self._reseed_ref_loop = None
|
| 519 |
|
| 520 |
+
# 3) Emit exactly bars_per_chunk × spb from the stream
|
| 521 |
+
start = self._next_emit_start
|
| 522 |
+
end = start + chunk_samps
|
| 523 |
+
if end > self._stream.shape[0]:
|
| 524 |
+
# shouldn't happen often; generate a bit more and loop
|
| 525 |
+
continue
|
| 526 |
|
| 527 |
+
slice_ = self._stream[start:end]
|
| 528 |
+
self._next_emit_start = end
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
|
| 530 |
+
y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
|
|
|
|
| 531 |
|
| 532 |
+
# 4) Post-processing / loudness
|
| 533 |
if self.idx == 0 and self.params.ref_loop is not None:
|
| 534 |
y, _ = match_loudness_to_reference(
|
| 535 |
self.params.ref_loop, y,
|
|
|
|
| 539 |
else:
|
| 540 |
apply_micro_fades(y, 3)
|
| 541 |
|
| 542 |
+
# 5) Resample + exact-length snap + encode
|
| 543 |
+
b64, meta = self._snap_and_encode(
|
| 544 |
+
y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
|
| 545 |
+
)
|
| 546 |
+
meta["xfade_seconds"] = xfade
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
|
| 548 |
+
# 6) Publish
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
with self._lock:
|
| 550 |
self.idx += 1
|
| 551 |
+
self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
|
|
|
|
|
|
|
|
|
|
| 552 |
if len(self.outbox) > 10:
|
| 553 |
cutoff = self._last_delivered_index - 5
|
| 554 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
| 555 |
|
| 556 |
+
# 👉 If a reseed was requested, apply it *now*, between chunks
|
| 557 |
if self._pending_reseed is not None:
|
| 558 |
pkg = self._pending_reseed
|
| 559 |
self._pending_reseed = None
|
| 560 |
|
| 561 |
new_state = self.mrt.init_state()
|
| 562 |
+
new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
|
| 563 |
self.state = new_state
|
| 564 |
|
| 565 |
+
# start a fresh stream and schedule one-time alignment
|
| 566 |
self._stream = None
|
| 567 |
+
self._next_emit_start = 0
|
| 568 |
+
self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
self._needs_bar_realign = True
|
| 570 |
|
| 571 |
+
print("🔁 Reseed installed at bar boundary; will realign before next slice")
|
| 572 |
|
| 573 |
+
print(f"✅ Completed chunk {self.idx}")
|
| 574 |
+
|
| 575 |
+
print("🛑 JamWorker stopped")
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -109,81 +109,55 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
| 109 |
|
| 110 |
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
-
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4
|
| 113 |
"""
|
| 114 |
Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
|
| 115 |
-
whole-bar boundary in codec-frame space.
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
"""
|
|
|
|
|
|
|
| 119 |
if tokens is None:
|
| 120 |
raise ValueError("tokens is None")
|
| 121 |
tokens = np.asarray(tokens)
|
| 122 |
if tokens.ndim == 1:
|
| 123 |
-
tokens = tokens[:, None]
|
| 124 |
|
| 125 |
T = tokens.shape[0]
|
| 126 |
if T == 0:
|
| 127 |
return tokens
|
| 128 |
|
| 129 |
fps = float(fps)
|
| 130 |
-
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
best_end = candidate_end
|
| 156 |
-
|
| 157 |
-
end_idx = best_end
|
| 158 |
-
start_idx = max(0, end_idx - ctx_frames)
|
| 159 |
-
|
| 160 |
-
window = tiled[start_idx:end_idx]
|
| 161 |
-
|
| 162 |
-
# Report timing info for debugging
|
| 163 |
-
actual_bars = end_idx / frames_per_bar_f
|
| 164 |
-
print(f"Context aligned to {actual_bars:.3f} bars (error: {best_error:.4f})")
|
| 165 |
-
|
| 166 |
-
else:
|
| 167 |
-
# Original logic for integer frames per bar
|
| 168 |
-
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
| 169 |
-
tiled = np.tile(tokens, (reps, 1))
|
| 170 |
-
total = tiled.shape[0]
|
| 171 |
-
|
| 172 |
-
k_bars = int(np.floor(total / frames_per_bar_f))
|
| 173 |
-
if k_bars <= 0:
|
| 174 |
-
window = tiled[-ctx_frames:]
|
| 175 |
-
return window
|
| 176 |
-
|
| 177 |
-
end_idx = int(round(k_bars * frames_per_bar_f))
|
| 178 |
-
end_idx = min(max(end_idx, ctx_frames), total)
|
| 179 |
-
start_idx = end_idx - ctx_frames
|
| 180 |
-
if start_idx < 0:
|
| 181 |
-
start_idx = 0
|
| 182 |
-
end_idx = ctx_frames
|
| 183 |
-
|
| 184 |
-
window = tiled[start_idx:end_idx]
|
| 185 |
-
|
| 186 |
-
# Ensure exact length
|
| 187 |
if window.shape[0] < ctx_frames:
|
| 188 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
| 189 |
window = np.vstack([window, pad])[:ctx_frames]
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
+
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
|
| 113 |
"""
|
| 114 |
Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
|
| 115 |
+
whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
|
| 116 |
+
|
| 117 |
+
tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
|
| 118 |
+
bpm: float
|
| 119 |
+
fps: float (codec frames per second; keep this as float)
|
| 120 |
+
ctx_frames: int (length of context window in codec frames)
|
| 121 |
+
beats_per_bar: int
|
| 122 |
"""
|
| 123 |
+
|
| 124 |
+
|
| 125 |
if tokens is None:
|
| 126 |
raise ValueError("tokens is None")
|
| 127 |
tokens = np.asarray(tokens)
|
| 128 |
if tokens.ndim == 1:
|
| 129 |
+
tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
|
| 130 |
|
| 131 |
T = tokens.shape[0]
|
| 132 |
if T == 0:
|
| 133 |
return tokens
|
| 134 |
|
| 135 |
fps = float(fps)
|
| 136 |
+
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
|
| 137 |
+
|
| 138 |
+
# Tile a little more than we need so we can always snap the END to a bar boundary
|
| 139 |
+
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
| 140 |
+
tiled = np.tile(tokens, (reps, 1))
|
| 141 |
+
total = tiled.shape[0]
|
| 142 |
+
|
| 143 |
+
# How many whole bars fit?
|
| 144 |
+
k_bars = int(np.floor(total / frames_per_bar_f))
|
| 145 |
+
if k_bars <= 0:
|
| 146 |
+
# Fallback: just take the last ctx_frames
|
| 147 |
+
window = tiled[-ctx_frames:]
|
| 148 |
+
return window
|
| 149 |
+
|
| 150 |
+
# Snap END index to the nearest integer frame at a whole-bar boundary
|
| 151 |
+
end_idx = int(round(k_bars * frames_per_bar_f))
|
| 152 |
+
end_idx = min(max(end_idx, ctx_frames), total)
|
| 153 |
+
start_idx = end_idx - ctx_frames
|
| 154 |
+
if start_idx < 0:
|
| 155 |
+
start_idx = 0
|
| 156 |
+
end_idx = ctx_frames
|
| 157 |
+
|
| 158 |
+
window = tiled[start_idx:end_idx]
|
| 159 |
+
|
| 160 |
+
# Guard against rare off-by-one due to rounding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
if window.shape[0] < ctx_frames:
|
| 162 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
| 163 |
window = np.vstack([window, pad])[:ctx_frames]
|