thecollabagepatch committed on
Commit
ca4c65f
·
1 Parent(s): d96e1a0

it wasn't the loudness matching. _append_model_chunk_and_spool debug now

Browse files
Files changed (1) hide show
  1. jam_worker.py +38 -94
jam_worker.py CHANGED
@@ -19,6 +19,13 @@ from utils import (
19
  wav_bytes_base64,
20
  )
21
 
 
 
 
 
 
 
 
22
  # -----------------------------
23
  # Data classes
24
  # -----------------------------
@@ -430,14 +437,13 @@ class JamWorker(threading.Thread):
430
  Conservative boundary fix:
431
  - Emit body+tail immediately (target SR), unchanged from your original behavior.
432
  - On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
433
- resample it, and overwrite the last `_pending_tail_target_len` samples in the
434
- target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
435
- remember THIS chunk's tail length at target SR for the next correction.
436
 
437
  This keeps external timing and bar alignment identical, but removes the audible
438
  fade-to-zero at chunk ends.
439
  """
440
-
441
 
442
  # ---- unpack model-rate samples ----
443
  s = wav.samples.astype(np.float32, copy=False)
@@ -471,6 +477,10 @@ class JamWorker(threading.Thread):
471
  y_mixed = to_target(mixed_model.astype(np.float32))
472
  Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
473
 
 
 
 
 
474
  # Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
475
  # Use the *smaller* of the two lengths to be safe.
476
  Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
@@ -510,6 +520,8 @@ class JamWorker(threading.Thread):
510
  if body.size:
511
  y_body = to_target(body.astype(np.float32))
512
  if y_body.size:
 
 
513
  self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
514
  self._spool_written += y_body.shape[0]
515
  else:
@@ -518,6 +530,8 @@ class JamWorker(threading.Thread):
518
  body = s[xfade_n:, :]
519
  y_body = to_target(body.astype(np.float32))
520
  if y_body.size:
 
 
521
  self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
522
  self._spool_written += y_body.shape[0]
523
  # No tail to remember this round
@@ -531,12 +545,14 @@ class JamWorker(threading.Thread):
531
  y_tail = to_target(tail.astype(np.float32))
532
  Ltail = int(y_tail.shape[0])
533
  if Ltail:
 
 
534
  self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
535
  self._spool_written += Ltail
536
  self._pending_tail_model = tail.copy()
537
  self._pending_tail_target_len = Ltail
538
  else:
539
- # Nothing appended (resampler returning nothing yet) — keep model tail but mark zero target len
540
  self._pending_tail_model = tail.copy()
541
  self._pending_tail_target_len = 0
542
  else:
@@ -544,6 +560,7 @@ class JamWorker(threading.Thread):
544
  self._pending_tail_target_len = 0
545
 
546
 
 
547
  def _should_generate_next_chunk(self) -> bool:
548
  # Allow running ahead relative to whichever is larger: last *consumed*
549
  # (explicit ack from client) or last *delivered* (implicit ack).
@@ -552,97 +569,20 @@ class JamWorker(threading.Thread):
552
  return self.idx <= (horizon_anchor + self._max_buffer_ahead)
553
 
554
  def _emit_ready(self):
555
- """Emit next chunk(s) if the spool has enough samples. With robust RMS debug."""
556
-
557
-
558
- QDB_SILENCE = -55.0
559
- EPS = 1e-12
560
-
561
- def rms_dbfs(x: np.ndarray) -> float:
562
- if x.ndim == 2:
563
- x = x.mean(axis=1)
564
- rms = float(np.sqrt(np.mean(np.square(x)) + EPS))
565
- return 20.0 * np.log10(max(rms, EPS))
566
-
567
- def qbar_rms_dbfs(x: np.ndarray, seg_len: int) -> list[float]:
568
- if x.ndim == 2:
569
- mono = x.mean(axis=1)
570
- else:
571
- mono = x
572
- N = mono.shape[0]
573
- vals = []
574
- for i in range(0, N, seg_len):
575
- seg = mono[i:min(i + seg_len, N)]
576
- if seg.size == 0:
577
- break
578
- r = float(np.sqrt(np.mean(seg * seg) + EPS))
579
- vals.append(20.0 * np.log10(max(r, EPS)))
580
- return vals
581
-
582
- def fmt_db_list(vals):
583
- return ['%5.1f' % v for v in vals[:8]]
584
-
585
- def extract_gain_db(g):
586
- # Accept float/int, dict{'gain_db': ...}, tuple/list, or None
587
- if g is None:
588
- return None
589
- if isinstance(g, (int, float)):
590
- return float(g)
591
- if isinstance(g, dict):
592
- for k in ('gain_db', 'gain', 'applied_gain_db'):
593
- if k in g:
594
- try:
595
- return float(g[k])
596
- except Exception:
597
- pass
598
- return None
599
- if isinstance(g, (list, tuple)) and g:
600
- try:
601
- return float(g[0])
602
- except Exception:
603
- return None
604
- return None
605
-
606
  while True:
607
  start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
608
  if end > self._spool_written:
609
- break
610
-
611
  loop = self._spool[start:end]
612
 
613
- # ---- pre-LM diagnostics ----
614
- spb = self._bar_clock.bar_samps
615
- qlen = max(1, spb // 4)
616
- q_rms_pre = qbar_rms_dbfs(loop, qlen)
617
- silent_marks_pre = ["🟢" if v > QDB_SILENCE else "🟥" for v in q_rms_pre[:8]]
618
- print(f"[emit idx={self.idx}] pre-LM qRMS dBFS: {fmt_db_list(q_rms_pre)} {''.join(silent_marks_pre)}")
619
-
620
- # Loudness match (optional)
621
- gain_db_applied_raw = None
622
  if self.params.ref_loop is not None and self.params.loudness_mode != "none":
623
  ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr)
624
  wav = au.Waveform(loop.copy(), int(self.params.target_sr))
625
- try:
626
- matched, gain_db_applied_raw = match_loudness_to_reference(
627
- ref, wav,
628
- method=self.params.loudness_mode,
629
- headroom_db=self.params.headroom_db
630
- )
631
- loop = matched.samples
632
- except Exception as e:
633
- print(f"[emit idx={self.idx}] loudness-match ERROR: {e}; proceeding with un-matched audio")
634
-
635
- gain_db = extract_gain_db(gain_db_applied_raw)
636
-
637
- # ---- post-LM diagnostics ----
638
- q_rms_post = qbar_rms_dbfs(loop, qlen)
639
- silent_marks_post = ["🟢" if v > QDB_SILENCE else "🟥" for v in q_rms_post[:8]]
640
- if gain_db is None:
641
- print(f"[emit idx={self.idx}] post-LM qRMS dBFS: {fmt_db_list(q_rms_post)} {''.join(silent_marks_post)} (LM: none)")
642
- else:
643
- print(f"[emit idx={self.idx}] post-LM qRMS dBFS: {fmt_db_list(q_rms_post)} {''.join(silent_marks_post)} (LM gain {gain_db:+.2f} dB)")
644
 
645
- # Encode & ship
646
  audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
647
  meta = {
648
  "bpm": float(self.params.bpm),
@@ -659,28 +599,34 @@ class JamWorker(threading.Thread):
659
  }
660
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
661
 
 
 
 
 
 
 
 
662
  with self._cv:
663
  self._outbox[self.idx] = chunk
664
  self._cv.notify_all()
665
-
666
- print(f"[emit idx={self.idx}] slice [{start}:{end}] (len={end-start}), spool_written={self._spool_written}")
667
  self.idx += 1
668
 
669
- # Apply pending splices/reseeds immediately after a completed emit
670
  with self._lock:
 
671
  if self._pending_token_splice is not None:
672
  spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
673
  try:
 
674
  self.state.context_tokens = spliced
675
  self._pending_token_splice = None
676
- print(f"[emit idx={self.idx}] installed token splice (in-place)")
677
  except Exception:
 
678
  new_state = self.mrt.init_state()
679
  new_state.context_tokens = spliced
680
  self.state = new_state
681
  self._model_stream = None
682
  self._pending_token_splice = None
683
- print(f"[emit idx={self.idx}] installed token splice (reinit state)")
684
  elif self._pending_reseed is not None:
685
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
686
  new_state = self.mrt.init_state()
@@ -688,8 +634,6 @@ class JamWorker(threading.Thread):
688
  self.state = new_state
689
  self._model_stream = None
690
  self._pending_reseed = None
691
- print(f"[emit idx={self.idx}] performed full reseed")
692
-
693
 
694
  # ---------- main loop ----------
695
 
 
19
  wav_bytes_base64,
20
  )
21
 
22
def _dbg_rms_dbfs(x: np.ndarray) -> float:
    """Return the RMS level of *x* in dB (20 * log10 of the RMS), for debug logging.

    A 2-D input is first averaged along axis 1 (presumably the channel axis,
    i.e. a mono downmix -- confirm against callers), then the RMS is taken
    over the remaining samples.

    A small epsilon (1e-12) is added to the mean square before the square
    root, so pure silence reports -120 dB instead of -inf.

    Args:
        x: Sample array, 1-D or 2-D. Anything accepted by ``np.asarray``.

    Returns:
        The level as a plain Python ``float``. An empty input returns the
        silence floor (-120.0) rather than NaN.
    """
    x = np.asarray(x, dtype=np.float64)

    # np.mean() of an empty array yields NaN plus a RuntimeWarning; report
    # the same floor that all-zero audio would produce instead.
    if x.size == 0:
        return -120.0

    if x.ndim == 2:
        x = x.mean(axis=1)

    # The epsilon keeps rms >= 1e-6, so no additional clamp is needed
    # before taking the logarithm.
    rms = float(np.sqrt(np.mean(x * x) + 1e-12))
    return float(20.0 * np.log10(rms))
28
+
29
  # -----------------------------
30
  # Data classes
31
  # -----------------------------
 
437
  Conservative boundary fix:
438
  - Emit body+tail immediately (target SR), unchanged from your original behavior.
439
  - On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
440
+ resample it, and overwrite the last `_pending_tail_target_len` samples in the
441
+ target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
442
+ remember THIS chunk's tail length at target SR for the next correction.
443
 
444
  This keeps external timing and bar alignment identical, but removes the audible
445
  fade-to-zero at chunk ends.
446
  """
 
447
 
448
  # ---- unpack model-rate samples ----
449
  s = wav.samples.astype(np.float32, copy=False)
 
477
  y_mixed = to_target(mixed_model.astype(np.float32))
478
  Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
479
 
480
+ # DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
481
+ if y_mixed.size:
482
+ print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
483
+
484
  # Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
485
  # Use the *smaller* of the two lengths to be safe.
486
  Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
 
520
  if body.size:
521
  y_body = to_target(body.astype(np.float32))
522
  if y_body.size:
523
+ # DEBUG: body RMS we are actually appending
524
+ print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
525
  self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
526
  self._spool_written += y_body.shape[0]
527
  else:
 
530
  body = s[xfade_n:, :]
531
  y_body = to_target(body.astype(np.float32))
532
  if y_body.size:
533
+ # DEBUG: body RMS in short-chunk path
534
+ print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
535
  self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
536
  self._spool_written += y_body.shape[0]
537
  # No tail to remember this round
 
545
  y_tail = to_target(tail.astype(np.float32))
546
  Ltail = int(y_tail.shape[0])
547
  if Ltail:
548
+ # DEBUG: tail RMS we are appending now (to be corrected next call)
549
+ print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
550
  self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
551
  self._spool_written += Ltail
552
  self._pending_tail_model = tail.copy()
553
  self._pending_tail_target_len = Ltail
554
  else:
555
+ # Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
556
  self._pending_tail_model = tail.copy()
557
  self._pending_tail_target_len = 0
558
  else:
 
560
  self._pending_tail_target_len = 0
561
 
562
 
563
+
564
  def _should_generate_next_chunk(self) -> bool:
565
  # Allow running ahead relative to whichever is larger: last *consumed*
566
  # (explicit ack from client) or last *delivered* (implicit ack).
 
569
  return self.idx <= (horizon_anchor + self._max_buffer_ahead)
570
 
571
  def _emit_ready(self):
572
+ """Emit next chunk(s) if the spool has enough samples."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  while True:
574
  start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
575
  if end > self._spool_written:
576
+ break # need more audio
 
577
  loop = self._spool[start:end]
578
 
579
+ # Loudness match to reference loop (optional)
 
 
 
 
 
 
 
 
580
  if self.params.ref_loop is not None and self.params.loudness_mode != "none":
581
  ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr)
582
  wav = au.Waveform(loop.copy(), int(self.params.target_sr))
583
+ matched, _ = match_loudness_to_reference(ref, wav, method=self.params.loudness_mode, headroom_db=self.params.headroom_db)
584
+ loop = matched.samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
 
586
  audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
587
  meta = {
588
  "bpm": float(self.params.bpm),
 
599
  }
600
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
601
 
602
+ if os.getenv("MRT_DEBUG_RMS", "0") == "1":
603
+ spb = self._bar_clock.bar_samps
604
+ seg = int(max(1, spb // 4)) # quarter-bar window
605
+
606
+ rms = [float(np.sqrt(np.mean(loop[i:i+seg]**2))) for i in range(0, loop.shape[0], seg)]
607
+ print(f"[emit idx={self.idx}] quarter-bar RMS: {rms[:8]}")
608
+
609
  with self._cv:
610
  self._outbox[self.idx] = chunk
611
  self._cv.notify_all()
 
 
612
  self.idx += 1
613
 
614
+ # If a reseed is queued, install it *right after* we finish a chunk
615
  with self._lock:
616
+ # Prefer seamless token splice when available
617
  if self._pending_token_splice is not None:
618
  spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
619
  try:
620
+ # inplace update (no reset)
621
  self.state.context_tokens = spliced
622
  self._pending_token_splice = None
 
623
  except Exception:
624
+ # fallback: full reseed using spliced tokens
625
  new_state = self.mrt.init_state()
626
  new_state.context_tokens = spliced
627
  self.state = new_state
628
  self._model_stream = None
629
  self._pending_token_splice = None
 
630
  elif self._pending_reseed is not None:
631
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
632
  new_state = self.mrt.init_state()
 
634
  self.state = new_state
635
  self._model_stream = None
636
  self._pending_reseed = None
 
 
637
 
638
  # ---------- main loop ----------
639