Commit
·
ca4c65f
1
Parent(s):
d96e1a0
it wasn't the loudness matching. _append_model_chunk_and_spool debug now
Browse files- jam_worker.py +38 -94
jam_worker.py
CHANGED
|
@@ -19,6 +19,13 @@ from utils import (
|
|
| 19 |
wav_bytes_base64,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# -----------------------------
|
| 23 |
# Data classes
|
| 24 |
# -----------------------------
|
|
@@ -430,14 +437,13 @@ class JamWorker(threading.Thread):
|
|
| 430 |
Conservative boundary fix:
|
| 431 |
- Emit body+tail immediately (target SR), unchanged from your original behavior.
|
| 432 |
- On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
|
| 437 |
This keeps external timing and bar alignment identical, but removes the audible
|
| 438 |
fade-to-zero at chunk ends.
|
| 439 |
"""
|
| 440 |
-
|
| 441 |
|
| 442 |
# ---- unpack model-rate samples ----
|
| 443 |
s = wav.samples.astype(np.float32, copy=False)
|
|
@@ -471,6 +477,10 @@ class JamWorker(threading.Thread):
|
|
| 471 |
y_mixed = to_target(mixed_model.astype(np.float32))
|
| 472 |
Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
# Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
|
| 475 |
# Use the *smaller* of the two lengths to be safe.
|
| 476 |
Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
|
|
@@ -510,6 +520,8 @@ class JamWorker(threading.Thread):
|
|
| 510 |
if body.size:
|
| 511 |
y_body = to_target(body.astype(np.float32))
|
| 512 |
if y_body.size:
|
|
|
|
|
|
|
| 513 |
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 514 |
self._spool_written += y_body.shape[0]
|
| 515 |
else:
|
|
@@ -518,6 +530,8 @@ class JamWorker(threading.Thread):
|
|
| 518 |
body = s[xfade_n:, :]
|
| 519 |
y_body = to_target(body.astype(np.float32))
|
| 520 |
if y_body.size:
|
|
|
|
|
|
|
| 521 |
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 522 |
self._spool_written += y_body.shape[0]
|
| 523 |
# No tail to remember this round
|
|
@@ -531,12 +545,14 @@ class JamWorker(threading.Thread):
|
|
| 531 |
y_tail = to_target(tail.astype(np.float32))
|
| 532 |
Ltail = int(y_tail.shape[0])
|
| 533 |
if Ltail:
|
|
|
|
|
|
|
| 534 |
self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
|
| 535 |
self._spool_written += Ltail
|
| 536 |
self._pending_tail_model = tail.copy()
|
| 537 |
self._pending_tail_target_len = Ltail
|
| 538 |
else:
|
| 539 |
-
# Nothing appended (resampler
|
| 540 |
self._pending_tail_model = tail.copy()
|
| 541 |
self._pending_tail_target_len = 0
|
| 542 |
else:
|
|
@@ -544,6 +560,7 @@ class JamWorker(threading.Thread):
|
|
| 544 |
self._pending_tail_target_len = 0
|
| 545 |
|
| 546 |
|
|
|
|
| 547 |
def _should_generate_next_chunk(self) -> bool:
|
| 548 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
| 549 |
# (explicit ack from client) or last *delivered* (implicit ack).
|
|
@@ -552,97 +569,20 @@ class JamWorker(threading.Thread):
|
|
| 552 |
return self.idx <= (horizon_anchor + self._max_buffer_ahead)
|
| 553 |
|
| 554 |
def _emit_ready(self):
|
| 555 |
-
"""Emit next chunk(s) if the spool has enough samples.
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
QDB_SILENCE = -55.0
|
| 559 |
-
EPS = 1e-12
|
| 560 |
-
|
| 561 |
-
def rms_dbfs(x: np.ndarray) -> float:
|
| 562 |
-
if x.ndim == 2:
|
| 563 |
-
x = x.mean(axis=1)
|
| 564 |
-
rms = float(np.sqrt(np.mean(np.square(x)) + EPS))
|
| 565 |
-
return 20.0 * np.log10(max(rms, EPS))
|
| 566 |
-
|
| 567 |
-
def qbar_rms_dbfs(x: np.ndarray, seg_len: int) -> list[float]:
|
| 568 |
-
if x.ndim == 2:
|
| 569 |
-
mono = x.mean(axis=1)
|
| 570 |
-
else:
|
| 571 |
-
mono = x
|
| 572 |
-
N = mono.shape[0]
|
| 573 |
-
vals = []
|
| 574 |
-
for i in range(0, N, seg_len):
|
| 575 |
-
seg = mono[i:min(i + seg_len, N)]
|
| 576 |
-
if seg.size == 0:
|
| 577 |
-
break
|
| 578 |
-
r = float(np.sqrt(np.mean(seg * seg) + EPS))
|
| 579 |
-
vals.append(20.0 * np.log10(max(r, EPS)))
|
| 580 |
-
return vals
|
| 581 |
-
|
| 582 |
-
def fmt_db_list(vals):
|
| 583 |
-
return ['%5.1f' % v for v in vals[:8]]
|
| 584 |
-
|
| 585 |
-
def extract_gain_db(g):
|
| 586 |
-
# Accept float/int, dict{'gain_db': ...}, tuple/list, or None
|
| 587 |
-
if g is None:
|
| 588 |
-
return None
|
| 589 |
-
if isinstance(g, (int, float)):
|
| 590 |
-
return float(g)
|
| 591 |
-
if isinstance(g, dict):
|
| 592 |
-
for k in ('gain_db', 'gain', 'applied_gain_db'):
|
| 593 |
-
if k in g:
|
| 594 |
-
try:
|
| 595 |
-
return float(g[k])
|
| 596 |
-
except Exception:
|
| 597 |
-
pass
|
| 598 |
-
return None
|
| 599 |
-
if isinstance(g, (list, tuple)) and g:
|
| 600 |
-
try:
|
| 601 |
-
return float(g[0])
|
| 602 |
-
except Exception:
|
| 603 |
-
return None
|
| 604 |
-
return None
|
| 605 |
-
|
| 606 |
while True:
|
| 607 |
start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
|
| 608 |
if end > self._spool_written:
|
| 609 |
-
break
|
| 610 |
-
|
| 611 |
loop = self._spool[start:end]
|
| 612 |
|
| 613 |
-
#
|
| 614 |
-
spb = self._bar_clock.bar_samps
|
| 615 |
-
qlen = max(1, spb // 4)
|
| 616 |
-
q_rms_pre = qbar_rms_dbfs(loop, qlen)
|
| 617 |
-
silent_marks_pre = ["🟢" if v > QDB_SILENCE else "🟥" for v in q_rms_pre[:8]]
|
| 618 |
-
print(f"[emit idx={self.idx}] pre-LM qRMS dBFS: {fmt_db_list(q_rms_pre)} {''.join(silent_marks_pre)}")
|
| 619 |
-
|
| 620 |
-
# Loudness match (optional)
|
| 621 |
-
gain_db_applied_raw = None
|
| 622 |
if self.params.ref_loop is not None and self.params.loudness_mode != "none":
|
| 623 |
ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr)
|
| 624 |
wav = au.Waveform(loop.copy(), int(self.params.target_sr))
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
ref, wav,
|
| 628 |
-
method=self.params.loudness_mode,
|
| 629 |
-
headroom_db=self.params.headroom_db
|
| 630 |
-
)
|
| 631 |
-
loop = matched.samples
|
| 632 |
-
except Exception as e:
|
| 633 |
-
print(f"[emit idx={self.idx}] loudness-match ERROR: {e}; proceeding with un-matched audio")
|
| 634 |
-
|
| 635 |
-
gain_db = extract_gain_db(gain_db_applied_raw)
|
| 636 |
-
|
| 637 |
-
# ---- post-LM diagnostics ----
|
| 638 |
-
q_rms_post = qbar_rms_dbfs(loop, qlen)
|
| 639 |
-
silent_marks_post = ["🟢" if v > QDB_SILENCE else "🟥" for v in q_rms_post[:8]]
|
| 640 |
-
if gain_db is None:
|
| 641 |
-
print(f"[emit idx={self.idx}] post-LM qRMS dBFS: {fmt_db_list(q_rms_post)} {''.join(silent_marks_post)} (LM: none)")
|
| 642 |
-
else:
|
| 643 |
-
print(f"[emit idx={self.idx}] post-LM qRMS dBFS: {fmt_db_list(q_rms_post)} {''.join(silent_marks_post)} (LM gain {gain_db:+.2f} dB)")
|
| 644 |
|
| 645 |
-
# Encode & ship
|
| 646 |
audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
|
| 647 |
meta = {
|
| 648 |
"bpm": float(self.params.bpm),
|
|
@@ -659,28 +599,34 @@ class JamWorker(threading.Thread):
|
|
| 659 |
}
|
| 660 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 661 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
with self._cv:
|
| 663 |
self._outbox[self.idx] = chunk
|
| 664 |
self._cv.notify_all()
|
| 665 |
-
|
| 666 |
-
print(f"[emit idx={self.idx}] slice [{start}:{end}] (len={end-start}), spool_written={self._spool_written}")
|
| 667 |
self.idx += 1
|
| 668 |
|
| 669 |
-
#
|
| 670 |
with self._lock:
|
|
|
|
| 671 |
if self._pending_token_splice is not None:
|
| 672 |
spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
|
| 673 |
try:
|
|
|
|
| 674 |
self.state.context_tokens = spliced
|
| 675 |
self._pending_token_splice = None
|
| 676 |
-
print(f"[emit idx={self.idx}] installed token splice (in-place)")
|
| 677 |
except Exception:
|
|
|
|
| 678 |
new_state = self.mrt.init_state()
|
| 679 |
new_state.context_tokens = spliced
|
| 680 |
self.state = new_state
|
| 681 |
self._model_stream = None
|
| 682 |
self._pending_token_splice = None
|
| 683 |
-
print(f"[emit idx={self.idx}] installed token splice (reinit state)")
|
| 684 |
elif self._pending_reseed is not None:
|
| 685 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 686 |
new_state = self.mrt.init_state()
|
|
@@ -688,8 +634,6 @@ class JamWorker(threading.Thread):
|
|
| 688 |
self.state = new_state
|
| 689 |
self._model_stream = None
|
| 690 |
self._pending_reseed = None
|
| 691 |
-
print(f"[emit idx={self.idx}] performed full reseed")
|
| 692 |
-
|
| 693 |
|
| 694 |
# ---------- main loop ----------
|
| 695 |
|
|
|
|
| 19 |
wav_bytes_base64,
|
| 20 |
)
|
| 21 |
|
| 22 |
+
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
| 23 |
+
|
| 24 |
+
if x.ndim == 2:
|
| 25 |
+
x = x.mean(axis=1)
|
| 26 |
+
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
| 27 |
+
return 20.0 * np.log10(max(r, 1e-12))
|
| 28 |
+
|
| 29 |
# -----------------------------
|
| 30 |
# Data classes
|
| 31 |
# -----------------------------
|
|
|
|
| 437 |
Conservative boundary fix:
|
| 438 |
- Emit body+tail immediately (target SR), unchanged from your original behavior.
|
| 439 |
- On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
|
| 440 |
+
resample it, and overwrite the last `_pending_tail_target_len` samples in the
|
| 441 |
+
target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
|
| 442 |
+
remember THIS chunk's tail length at target SR for the next correction.
|
| 443 |
|
| 444 |
This keeps external timing and bar alignment identical, but removes the audible
|
| 445 |
fade-to-zero at chunk ends.
|
| 446 |
"""
|
|
|
|
| 447 |
|
| 448 |
# ---- unpack model-rate samples ----
|
| 449 |
s = wav.samples.astype(np.float32, copy=False)
|
|
|
|
| 477 |
y_mixed = to_target(mixed_model.astype(np.float32))
|
| 478 |
Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
|
| 479 |
|
| 480 |
+
# DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
|
| 481 |
+
if y_mixed.size:
|
| 482 |
+
print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
|
| 483 |
+
|
| 484 |
# Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
|
| 485 |
# Use the *smaller* of the two lengths to be safe.
|
| 486 |
Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
|
|
|
|
| 520 |
if body.size:
|
| 521 |
y_body = to_target(body.astype(np.float32))
|
| 522 |
if y_body.size:
|
| 523 |
+
# DEBUG: body RMS we are actually appending
|
| 524 |
+
print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
| 525 |
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 526 |
self._spool_written += y_body.shape[0]
|
| 527 |
else:
|
|
|
|
| 530 |
body = s[xfade_n:, :]
|
| 531 |
y_body = to_target(body.astype(np.float32))
|
| 532 |
if y_body.size:
|
| 533 |
+
# DEBUG: body RMS in short-chunk path
|
| 534 |
+
print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
| 535 |
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 536 |
self._spool_written += y_body.shape[0]
|
| 537 |
# No tail to remember this round
|
|
|
|
| 545 |
y_tail = to_target(tail.astype(np.float32))
|
| 546 |
Ltail = int(y_tail.shape[0])
|
| 547 |
if Ltail:
|
| 548 |
+
# DEBUG: tail RMS we are appending now (to be corrected next call)
|
| 549 |
+
print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
|
| 550 |
self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
|
| 551 |
self._spool_written += Ltail
|
| 552 |
self._pending_tail_model = tail.copy()
|
| 553 |
self._pending_tail_target_len = Ltail
|
| 554 |
else:
|
| 555 |
+
# Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
|
| 556 |
self._pending_tail_model = tail.copy()
|
| 557 |
self._pending_tail_target_len = 0
|
| 558 |
else:
|
|
|
|
| 560 |
self._pending_tail_target_len = 0
|
| 561 |
|
| 562 |
|
| 563 |
+
|
| 564 |
def _should_generate_next_chunk(self) -> bool:
|
| 565 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
| 566 |
# (explicit ack from client) or last *delivered* (implicit ack).
|
|
|
|
| 569 |
return self.idx <= (horizon_anchor + self._max_buffer_ahead)
|
| 570 |
|
| 571 |
def _emit_ready(self):
|
| 572 |
+
"""Emit next chunk(s) if the spool has enough samples."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
while True:
|
| 574 |
start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
|
| 575 |
if end > self._spool_written:
|
| 576 |
+
break # need more audio
|
|
|
|
| 577 |
loop = self._spool[start:end]
|
| 578 |
|
| 579 |
+
# Loudness match to reference loop (optional)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
if self.params.ref_loop is not None and self.params.loudness_mode != "none":
|
| 581 |
ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr)
|
| 582 |
wav = au.Waveform(loop.copy(), int(self.params.target_sr))
|
| 583 |
+
matched, _ = match_loudness_to_reference(ref, wav, method=self.params.loudness_mode, headroom_db=self.params.headroom_db)
|
| 584 |
+
loop = matched.samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
|
|
|
|
| 586 |
audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
|
| 587 |
meta = {
|
| 588 |
"bpm": float(self.params.bpm),
|
|
|
|
| 599 |
}
|
| 600 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 601 |
|
| 602 |
+
if os.getenv("MRT_DEBUG_RMS", "0") == "1":
|
| 603 |
+
spb = self._bar_clock.bar_samps
|
| 604 |
+
seg = int(max(1, spb // 4)) # quarter-bar window
|
| 605 |
+
|
| 606 |
+
rms = [float(np.sqrt(np.mean(loop[i:i+seg]**2))) for i in range(0, loop.shape[0], seg)]
|
| 607 |
+
print(f"[emit idx={self.idx}] quarter-bar RMS: {rms[:8]}")
|
| 608 |
+
|
| 609 |
with self._cv:
|
| 610 |
self._outbox[self.idx] = chunk
|
| 611 |
self._cv.notify_all()
|
|
|
|
|
|
|
| 612 |
self.idx += 1
|
| 613 |
|
| 614 |
+
# If a reseed is queued, install it *right after* we finish a chunk
|
| 615 |
with self._lock:
|
| 616 |
+
# Prefer seamless token splice when available
|
| 617 |
if self._pending_token_splice is not None:
|
| 618 |
spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
|
| 619 |
try:
|
| 620 |
+
# inplace update (no reset)
|
| 621 |
self.state.context_tokens = spliced
|
| 622 |
self._pending_token_splice = None
|
|
|
|
| 623 |
except Exception:
|
| 624 |
+
# fallback: full reseed using spliced tokens
|
| 625 |
new_state = self.mrt.init_state()
|
| 626 |
new_state.context_tokens = spliced
|
| 627 |
self.state = new_state
|
| 628 |
self._model_stream = None
|
| 629 |
self._pending_token_splice = None
|
|
|
|
| 630 |
elif self._pending_reseed is not None:
|
| 631 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 632 |
new_state = self.mrt.init_state()
|
|
|
|
| 634 |
self.state = new_state
|
| 635 |
self._model_stream = None
|
| 636 |
self._pending_reseed = None
|
|
|
|
|
|
|
| 637 |
|
| 638 |
# ---------- main loop ----------
|
| 639 |
|