Commit
Β·
c4dc2c2
1
Parent(s):
1b98b73
fixing flam
Browse files- jam_worker.py +77 -52
jam_worker.py
CHANGED
|
@@ -355,49 +355,74 @@ class JamWorker(threading.Thread):
|
|
| 355 |
chunk_secs = self.params.bars_per_chunk * spb
|
| 356 |
xfade = float(self.mrt.config.crossfade_length) # seconds
|
| 357 |
|
| 358 |
-
#
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
print("π JamWorker started with flow control...")
|
| 391 |
|
| 392 |
while not self._stop_event.is_set():
|
| 393 |
# Don't get too far ahead of the consumer
|
| 394 |
if not self._should_generate_next_chunk():
|
| 395 |
-
|
| 396 |
-
# (kept short so stop() stays responsive)
|
| 397 |
time.sleep(0.5)
|
| 398 |
continue
|
| 399 |
|
| 400 |
-
# Snapshot knobs + compute index
|
| 401 |
with self._lock:
|
| 402 |
style_vec = self.params.style_vec
|
| 403 |
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
|
@@ -409,12 +434,10 @@ class JamWorker(threading.Thread):
|
|
| 409 |
self.last_chunk_started_at = time.time()
|
| 410 |
|
| 411 |
# ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
|
| 412 |
-
#
|
| 413 |
assembled = 0.0
|
| 414 |
chunks = []
|
| 415 |
-
|
| 416 |
while assembled < chunk_secs and not self._stop_event.is_set():
|
| 417 |
-
# generate_chunk returns (au.Waveform, new_state)
|
| 418 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 419 |
chunks.append(wav)
|
| 420 |
L = wav.samples.shape[0] / float(self.mrt.sample_rate)
|
|
@@ -423,27 +446,30 @@ class JamWorker(threading.Thread):
|
|
| 423 |
if self._stop_event.is_set():
|
| 424 |
break
|
| 425 |
|
| 426 |
-
# ---- Stitch
|
| 427 |
-
|
| 428 |
-
# Preferred path if you've added the new param in utils.stitch_generated
|
| 429 |
-
y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
|
| 430 |
-
except TypeError:
|
| 431 |
-
# Backward-compatible: local stitcher that keeps the head
|
| 432 |
-
y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
|
| 433 |
-
|
| 434 |
-
# Hard trim to the exact musical duration (still at model SR)
|
| 435 |
y = hard_trim_seconds(y, chunk_secs)
|
| 436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
# ---- Post-processing ----
|
| 438 |
if next_idx == 1 and self.params.ref_loop is not None:
|
| 439 |
-
# match loudness to the provided reference on the very first audible chunk
|
| 440 |
y, _ = match_loudness_to_reference(
|
| 441 |
self.params.ref_loop, y,
|
| 442 |
method=self.params.loudness_mode,
|
| 443 |
headroom_db=self.params.headroom_db
|
| 444 |
)
|
| 445 |
else:
|
| 446 |
-
# light micro-fades to guard against clicks
|
| 447 |
apply_micro_fades(y, 3)
|
| 448 |
|
| 449 |
# ---- Resample + bar-snap + encode ----
|
|
@@ -453,14 +479,12 @@ class JamWorker(threading.Thread):
|
|
| 453 |
target_sr=self.params.target_sr,
|
| 454 |
bars=self.params.bars_per_chunk
|
| 455 |
)
|
| 456 |
-
#
|
| 457 |
-
meta["xfade_seconds"] = xfade
|
| 458 |
|
| 459 |
-
# ---- Publish
|
| 460 |
with self._lock:
|
| 461 |
self.idx = next_idx
|
| 462 |
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
| 463 |
-
# Keep outbox bounded (trim far-behind entries)
|
| 464 |
if len(self.outbox) > 10:
|
| 465 |
cutoff = self._last_delivered_index - 5
|
| 466 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
|
@@ -469,3 +493,4 @@ class JamWorker(threading.Thread):
|
|
| 469 |
print(f"β
Completed chunk {next_idx}")
|
| 470 |
|
| 471 |
print("π JamWorker stopped")
|
|
|
|
|
|
| 355 |
chunk_secs = self.params.bars_per_chunk * spb
|
| 356 |
xfade = float(self.mrt.config.crossfade_length) # seconds
|
| 357 |
|
| 358 |
+
# ---- tiny helper: mono + simple envelope ----
|
| 359 |
+
def _mono_env(x: np.ndarray, sr: int, win_ms: float = 20.0) -> np.ndarray:
|
| 360 |
+
if x.ndim == 2:
|
| 361 |
+
x = x.mean(axis=1)
|
| 362 |
+
x = np.abs(x).astype(np.float32)
|
| 363 |
+
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
| 364 |
+
if w == 1:
|
| 365 |
+
return x
|
| 366 |
+
kern = np.ones(w, dtype=np.float32) / float(w)
|
| 367 |
+
# moving average (same length)
|
| 368 |
+
return np.convolve(x, kern, mode="same")
|
| 369 |
+
|
| 370 |
+
# ---- estimate how late the first downbeat is (<= max_ms) ----
|
| 371 |
+
def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, max_ms: int = 120) -> int:
    """Estimate how many samples late the generated chunk's first downbeat is.

    Nested helper: closes over ``spb`` (seconds per bar) from the enclosing
    scope and calls the sibling ``_mono_env`` envelope helper.

    Cross-correlates the amplitude envelope of the reference loop's last bar
    against the head of the generated audio over lags of 0..``max_ms`` ms and
    returns the best-scoring lag. Deliberately best-effort: any failure
    (missing data, unexpected shapes, API errors) yields 0, i.e. "no offset".

    Args:
        ref_loop_wav: reference loop waveform; has ``.sample_rate``,
            ``.samples`` of shape (n_samples, n_channels), and ``.resample()``.
        gen_wav: freshly generated waveform with the same interface;
            assumed to already be at sample rate ``sr``.
        sr: model sample rate in Hz.
        max_ms: maximum positive lag to search, in milliseconds.

    Returns:
        Non-negative lag in samples (gen late => positive lag), or 0 when
        no offset could be estimated.
    """
    try:
        # resample ref to model SR if needed
        ref = ref_loop_wav
        if ref.sample_rate != sr:
            ref = ref.resample(sr)
        # last 1 bar of the reference (what the model just "heard")
        n_bar = int(round(spb * sr))
        ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
        # first 2 bars of the generated chunk (search window)
        gen_head = gen_wav.samples[: int(2 * n_bar), :]
        if ref_tail.size == 0 or gen_head.size == 0:
            return 0

        # envelopes
        e_ref = _mono_env(ref_tail, sr)
        e_gen = _mono_env(gen_head, sr)

        max_lag = int(round((max_ms / 1000.0) * sr))
        # ensure the window is long enough
        seg = min(len(e_ref), len(e_gen))
        e_ref = e_ref[-seg:]
        e_gen = e_gen[: seg + max_lag]  # allow positive lag (gen late)

        if len(e_gen) < seg:
            return 0

        # brute-force short-range correlation (gen late => positive lag)
        best_lag = 0
        best_score = -1e9
        for lag in range(0, max_lag + 1):
            a = e_ref
            b = e_gen[lag : lag + seg]
            # Ran off the end of the generated envelope: no more full windows.
            if len(b) != seg:
                break
            # normalized dot to be robust-ish
            # (the `or 1.0` guards against a zero norm, i.e. silent audio)
            denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
            score = float(np.dot(a, b) / denom)
            if score > best_score:
                best_score = score
                best_lag = lag
        return int(best_lag)
    except Exception:
        # Best-effort by design: alignment must never crash the jam loop.
        return 0
|
| 415 |
|
| 416 |
print("π JamWorker started with flow control...")
|
| 417 |
|
| 418 |
while not self._stop_event.is_set():
|
| 419 |
# Don't get too far ahead of the consumer
|
| 420 |
if not self._should_generate_next_chunk():
|
| 421 |
+
print("βΈοΈ Buffer full, waiting for consumption...")
|
|
|
|
| 422 |
time.sleep(0.5)
|
| 423 |
continue
|
| 424 |
|
| 425 |
+
# Snapshot knobs + compute index
|
| 426 |
with self._lock:
|
| 427 |
style_vec = self.params.style_vec
|
| 428 |
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
|
|
|
| 434 |
self.last_chunk_started_at = time.time()
|
| 435 |
|
| 436 |
# ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
|
| 437 |
+
# First sub-chunk contributes full L; subsequent contribute (L - xfade)
|
| 438 |
assembled = 0.0
|
| 439 |
chunks = []
|
|
|
|
| 440 |
while assembled < chunk_secs and not self._stop_event.is_set():
|
|
|
|
| 441 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 442 |
chunks.append(wav)
|
| 443 |
L = wav.samples.shape[0] / float(self.mrt.sample_rate)
|
|
|
|
| 446 |
if self._stop_event.is_set():
|
| 447 |
break
|
| 448 |
|
| 449 |
+
# ---- Stitch (utils drops the very first model pre-roll) & trim at model SR ----
|
| 450 |
+
y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
y = hard_trim_seconds(y, chunk_secs)
|
| 452 |
|
| 453 |
+
# ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
|
| 454 |
+
if next_idx == 1 and self.params.combined_loop is not None:
|
| 455 |
+
offset = _estimate_first_offset_samples(
|
| 456 |
+
self.params.combined_loop, y, int(self.mrt.sample_rate), max_ms=120
|
| 457 |
+
)
|
| 458 |
+
if offset > 0:
|
| 459 |
+
# Trim the head by the detected offset; we'll snap length later
|
| 460 |
+
y.samples = y.samples[offset:, :]
|
| 461 |
+
print(f"π― First-chunk offset compensation: -{offset/self.mrt.sample_rate:.3f}s")
|
| 462 |
+
# hard trim again (defensive), remaining length exactness happens in _snap_and_encode
|
| 463 |
+
y = hard_trim_seconds(y, chunk_secs)
|
| 464 |
+
|
| 465 |
# ---- Post-processing ----
|
| 466 |
if next_idx == 1 and self.params.ref_loop is not None:
|
|
|
|
| 467 |
y, _ = match_loudness_to_reference(
|
| 468 |
self.params.ref_loop, y,
|
| 469 |
method=self.params.loudness_mode,
|
| 470 |
headroom_db=self.params.headroom_db
|
| 471 |
)
|
| 472 |
else:
|
|
|
|
| 473 |
apply_micro_fades(y, 3)
|
| 474 |
|
| 475 |
# ---- Resample + bar-snap + encode ----
|
|
|
|
| 479 |
target_sr=self.params.target_sr,
|
| 480 |
bars=self.params.bars_per_chunk
|
| 481 |
)
|
| 482 |
+
meta["xfade_seconds"] = xfade # tiny hint for client if you want butter at chunk joins
|
|
|
|
| 483 |
|
| 484 |
+
# ---- Publish ----
|
| 485 |
with self._lock:
|
| 486 |
self.idx = next_idx
|
| 487 |
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
|
|
|
| 488 |
if len(self.outbox) > 10:
|
| 489 |
cutoff = self._last_delivered_index - 5
|
| 490 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
|
|
|
| 493 |
print(f"β
Completed chunk {next_idx}")
|
| 494 |
|
| 495 |
print("π JamWorker stopped")
|
| 496 |
+
|