thecollabagepatch committed on
Commit
c4a08cc
·
1 Parent(s): 9f21d72

buffer queue customization added to websockets route

Browse files
Files changed (2) hide show
  1. app.py +117 -10
  2. magentaRT_rt_tester.html +49 -2
app.py CHANGED
@@ -1647,16 +1647,37 @@ async def ws_jam(websocket: WebSocket):
1647
 
1648
  # kick off the ~2s streaming loop
1649
  async def _rt_loop():
 
 
 
 
1650
  try:
1651
  mrt = websocket._mrt
1652
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1653
  target_next = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1654
  while websocket._rt_running:
 
1655
  mrt.guidance_weight = websocket._rt_guid
1656
  mrt.temperature = websocket._rt_temp
1657
  mrt.topk = websocket._rt_topk
1658
 
1659
- # ramp style
1660
  ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
1661
  if ramp <= 0.0:
1662
  websocket._style_cur = websocket._style_tgt
@@ -1664,38 +1685,100 @@ async def ws_jam(websocket: WebSocket):
1664
  step = min(1.0, chunk_secs / ramp)
1665
  websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
1666
 
 
 
1667
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
1668
  websocket._state = new_state
1669
-
 
 
1670
  x = wav.samples.astype(np.float32, copy=False)
1671
  buf = io.BytesIO()
1672
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1673
 
 
1674
  ok = True
 
1675
  if binary_audio:
1676
  try:
1677
  await websocket.send_bytes(buf.getvalue())
1678
- ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
 
 
 
 
 
 
1679
  except Exception:
1680
  ok = False
1681
  else:
1682
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1683
- ok = await send_json({"type": "chunk", "audio_base64": b64,
1684
- "metadata": {"sample_rate": mrt.sample_rate}})
 
 
 
 
 
 
 
1685
 
1686
  if not ok:
1687
  break
1688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1689
  if getattr(websocket, "_pace", "asap") == "realtime":
1690
- t1 = time.perf_counter()
1691
  target_next += chunk_secs
1692
- sleep_s = max(0.0, target_next - t1 - 0.03)
 
1693
  if sleep_s > 0:
1694
  await asyncio.sleep(sleep_s)
 
 
 
 
 
 
 
 
 
1695
  except asyncio.CancelledError:
1696
  pass
1697
- except Exception:
1698
- pass
 
 
1699
 
1700
  websocket._rt_task = asyncio.create_task(_rt_loop())
1701
  continue # skip the “bar-mode started” message below
@@ -1737,11 +1820,18 @@ async def ws_jam(websocket: WebSocket):
1737
  )
1738
  await send_json({"type":"status", **res}) # {"ok": True}
1739
  else:
1740
- # rt-mode: there’s no JamWorker; update the local knobs/state
1741
  websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1742
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1743
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1744
 
 
 
 
 
 
 
 
1745
  # NEW steering fields
1746
  if "mean" in msg and msg["mean"] is not None:
1747
  try: websocket._rt_mean = float(msg["mean"])
@@ -1761,6 +1851,7 @@ async def ws_jam(websocket: WebSocket):
1761
  text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1762
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1763
 
 
1764
  asset_manager.ensure_assets_loaded(get_mrt())
1765
  websocket._style_tgt = build_style_vector(
1766
  websocket._mrt,
@@ -1771,12 +1862,28 @@ async def ws_jam(websocket: WebSocket):
1771
  mean_weight=float(websocket._rt_mean),
1772
  centroid_weights=websocket._rt_centroid_weights,
1773
  )
 
 
 
 
1774
  # optionally allow live changes to ramp:
1775
  if "style_ramp_seconds" in msg:
1776
  try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
1777
  except: pass
 
1778
  await send_json({"type":"status","updated":"rt-knobs+style"})
1779
 
 
 
 
 
 
 
 
 
 
 
 
1780
  elif mtype == "consume" and mode == "bar":
1781
  with jam_lock:
1782
  worker = jam_registry.get(msg.get("session_id"))
 
1647
 
1648
  # kick off the ~2s streaming loop
1649
  async def _rt_loop():
1650
+ """
1651
+ Enhanced realtime generation loop with adaptive pacing.
1652
+ Prevents buffer underruns while keeping style updates responsive.
1653
+ """
1654
  try:
1655
  mrt = websocket._mrt
1656
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1657
  target_next = time.perf_counter()
1658
+
1659
+ # ADAPTIVE PACING STATE
1660
+ # These thresholds define when to speed up or slow down
1661
+ BUFFER_LOW_THRESHOLD = 5.0 # Speed up if buffer < 5s
1662
+ BUFFER_TARGET = 5.0 # Target buffer level
1663
+ BUFFER_HIGH_THRESHOLD = 8.0 # Slow down if buffer > 8s
1664
+ BURST_CHUNKS = 3 # Number of chunks to burst ahead after updates
1665
+
1666
+ # Pacing lookahead values (how far ahead to stay)
1667
+ LOOKAHEAD_BURST = 0.0 # No sleep during burst (go as fast as possible)
1668
+ LOOKAHEAD_NORMAL = 0.02 # Normal realtime pacing (the previous fixed lookahead was 0.03)
1669
+ LOOKAHEAD_SLOW = 0.10 # When buffer is high, can afford more latency
1670
+
1671
+ burst_countdown = 2 # Chunks remaining in burst mode
1672
+ last_buffer_level = BUFFER_TARGET # Start assuming target
1673
+
1674
  while websocket._rt_running:
1675
+ # Update model parameters (these are fast, just attribute assignments)
1676
  mrt.guidance_weight = websocket._rt_guid
1677
  mrt.temperature = websocket._rt_temp
1678
  mrt.topk = websocket._rt_topk
1679
 
1680
+ # Ramp style vector (already implemented, keep as-is)
1681
  ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
1682
  if ramp <= 0.0:
1683
  websocket._style_cur = websocket._style_tgt
 
1685
  step = min(1.0, chunk_secs / ramp)
1686
  websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
1687
 
1688
+ # GENERATE CHUNK (this is the heavy operation)
1689
+ t_gen_start = time.perf_counter()
1690
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
1691
  websocket._state = new_state
1692
+ t_gen_end = time.perf_counter()
1693
+
1694
+ # Encode audio
1695
  x = wav.samples.astype(np.float32, copy=False)
1696
  buf = io.BytesIO()
1697
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1698
 
1699
+ # Send to client
1700
  ok = True
1701
+ t_send_start = time.perf_counter()
1702
  if binary_audio:
1703
  try:
1704
  await websocket.send_bytes(buf.getvalue())
1705
+ ok = await send_json({
1706
+ "type": "chunk_meta",
1707
+ "metadata": {
1708
+ "sample_rate": mrt.sample_rate,
1709
+ "generation_time_ms": int((t_gen_end - t_gen_start) * 1000),
1710
+ }
1711
+ })
1712
  except Exception:
1713
  ok = False
1714
  else:
1715
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1716
+ ok = await send_json({
1717
+ "type": "chunk",
1718
+ "audio_base64": b64,
1719
+ "metadata": {
1720
+ "sample_rate": mrt.sample_rate,
1721
+ "generation_time_ms": int((t_gen_end - t_gen_start) * 1000),
1722
+ }
1723
+ })
1724
+ t_send_end = time.perf_counter()
1725
 
1726
  if not ok:
1727
  break
1728
 
1729
+ # ADAPTIVE PACING LOGIC
1730
+ # Read buffer level from websocket attribute (updated by frontend in update messages)
1731
+ current_buffer = getattr(websocket, "_frontend_buffer_seconds", last_buffer_level)
1732
+ last_buffer_level = current_buffer
1733
+
1734
+ # Check if we received an update signal (set by the update message handler)
1735
+ if getattr(websocket, "_rt_update_received", False):
1736
+ # Burst ahead for a few chunks to rebuild buffer after style change
1737
+ burst_countdown = BURST_CHUNKS
1738
+ websocket._rt_update_received = False
1739
+
1740
+ # Determine pacing mode based on buffer level and burst state
1741
+ if burst_countdown > 0:
1742
+ # BURST MODE: Go as fast as possible
1743
+ lookahead = LOOKAHEAD_BURST
1744
+ burst_countdown -= 1
1745
+ pacing_mode = "burst"
1746
+ elif current_buffer < BUFFER_LOW_THRESHOLD:
1747
+ # LOW BUFFER: Speed up
1748
+ lookahead = LOOKAHEAD_BURST # No sleep, catch up
1749
+ pacing_mode = "catching_up"
1750
+ elif current_buffer > BUFFER_HIGH_THRESHOLD:
1751
+ # HIGH BUFFER: Can afford to slow down
1752
+ lookahead = LOOKAHEAD_SLOW
1753
+ pacing_mode = "relaxed"
1754
+ else:
1755
+ # NORMAL: Target buffer range
1756
+ lookahead = LOOKAHEAD_NORMAL
1757
+ pacing_mode = "normal"
1758
+
1759
+ # Apply pacing only if not in "asap" mode
1760
  if getattr(websocket, "_pace", "asap") == "realtime":
1761
+ t_now = time.perf_counter()
1762
  target_next += chunk_secs
1763
+ sleep_s = max(0.0, target_next - t_now - lookahead)
1764
+
1765
  if sleep_s > 0:
1766
  await asyncio.sleep(sleep_s)
1767
+
1768
+ # Debug logging (can be removed in production)
1769
+ gen_ms = int((t_gen_end - t_gen_start) * 1000)
1770
+ send_ms = int((t_send_end - t_send_start) * 1000)
1771
+ print(f"[RT] buffer:{current_buffer:.1f}s mode:{pacing_mode} gen:{gen_ms}ms send:{send_ms}ms sleep:{int(sleep_s*1000)}ms")
1772
+ else:
1773
+ # ASAP mode: don't sleep at all
1774
+ pass
1775
+
1776
  except asyncio.CancelledError:
1777
  pass
1778
+ except Exception as e:
1779
+ print(f"[RT] generation error: {e}")
1780
+ import traceback
1781
+ traceback.print_exc()
1782
 
1783
  websocket._rt_task = asyncio.create_task(_rt_loop())
1784
  continue # skip the “bar-mode started” message below
 
1820
  )
1821
  await send_json({"type":"status", **res}) # {"ok": True}
1822
  else:
1823
+ # rt-mode: update knobs and style
1824
  websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1825
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1826
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1827
 
1828
+ # NEW: Read frontend buffer level from update message
1829
+ if "frontend_buffer_seconds" in msg:
1830
+ try:
1831
+ websocket._frontend_buffer_seconds = float(msg["frontend_buffer_seconds"])
1832
+ except:
1833
+ pass
1834
+
1835
  # NEW steering fields
1836
  if "mean" in msg and msg["mean"] is not None:
1837
  try: websocket._rt_mean = float(msg["mean"])
 
1851
  text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1852
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1853
 
1854
+ # Build style vector (this can take 50-200ms)
1855
  asset_manager.ensure_assets_loaded(get_mrt())
1856
  websocket._style_tgt = build_style_vector(
1857
  websocket._mrt,
 
1862
  mean_weight=float(websocket._rt_mean),
1863
  centroid_weights=websocket._rt_centroid_weights,
1864
  )
1865
+
1866
+ # Signal to generation loop that update occurred (trigger burst mode)
1867
+ websocket._rt_update_received = True
1868
+
1869
  # optionally allow live changes to ramp:
1870
  if "style_ramp_seconds" in msg:
1871
  try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
1872
  except: pass
1873
+
1874
  await send_json({"type":"status","updated":"rt-knobs+style"})
1875
 
1876
+ elif mtype == "buffer_status":
1877
+ # Frontend reporting its buffer level for adaptive pacing
1878
+ if "frontend_buffer_seconds" in msg:
1879
+ try:
1880
+ websocket._frontend_buffer_seconds = float(msg["frontend_buffer_seconds"])
1881
+ # Optional: log for monitoring
1882
+ # print(f"[RT] frontend buffer: {websocket._frontend_buffer_seconds:.1f}s")
1883
+ except:
1884
+ pass
1885
+ # No response needed, this is just status info
1886
+
1887
  elif mtype == "consume" and mode == "bar":
1888
  with jam_lock:
1889
  worker = jam_registry.get(msg.get("session_id"))
magentaRT_rt_tester.html CHANGED
@@ -353,6 +353,7 @@ function beginPlaybackFromPending() {
353
  let ws = null;
354
  let connected = false;
355
  let autoUpdateTimer = null;
 
356
 
357
  /**
358
  * Push a line into the log ring and schedule a single repaint via rAF.
@@ -391,7 +392,20 @@ function beginPlaybackFromPending() {
391
 
392
  function updateQueueUI() {
393
  const total = scheduled.reduce((acc, s) => acc + s.dur, 0);
394
- queueEl.textContent = `${scheduled.length} buffers, ${total.toFixed(2)}s scheduled`;
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  }
396
 
397
  function clearSchedule() {
@@ -452,7 +466,11 @@ async function scheduleWavBytes(arrayBuffer) {
452
 
453
  function sendUpdate() {
454
  if (!ws || ws.readyState !== 1) return;
455
- const msg = { type: "update", ...currentParams() };
 
 
 
 
456
  ws.send(JSON.stringify(msg));
457
  log("→ update " + JSON.stringify(msg), "small");
458
  }
@@ -463,6 +481,30 @@ async function scheduleWavBytes(arrayBuffer) {
463
  autoUpdateTimer = setTimeout(sendUpdate, 150);
464
  }
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  function linkRangeNumber(range, number, cb) {
467
  const sync = (fromRange) => {
468
  if (fromRange) number.value = range.value;
@@ -628,6 +670,9 @@ async function scheduleWavBytes(arrayBuffer) {
628
  ws.send(JSON.stringify(msg));
629
  log("→ start " + JSON.stringify(msg), "ok");
630
  nextTime = ctx.currentTime + 0.12;
 
 
 
631
  };
632
 
633
  ws.onmessage = async (ev) => {
@@ -682,6 +727,7 @@ async function scheduleWavBytes(arrayBuffer) {
682
  btnStop.disabled = true;
683
  setStatus("closed");
684
  log("connection closed", "warn");
 
685
  };
686
 
687
  ws.onerror = (e) => {
@@ -691,6 +737,7 @@ async function scheduleWavBytes(arrayBuffer) {
691
 
692
  function stop() {
693
  if (!connected) return;
 
694
  try {
695
  ws?.send(JSON.stringify({ type: "stop" }));
696
  } catch {}
 
353
  let ws = null;
354
  let connected = false;
355
  let autoUpdateTimer = null;
356
+ let bufferStatusInterval = null; // NEW: For periodic buffer status reporting
357
 
358
  /**
359
  * Push a line into the log ring and schedule a single repaint via rAF.
 
392
 
393
  function updateQueueUI() {
394
  const total = scheduled.reduce((acc, s) => acc + s.dur, 0);
395
+ const bufferLevel = getBufferLevel();
396
+ const bufferStatus =
397
+ bufferLevel < 2.0 ? '🔴 CRITICAL' :
398
+ bufferLevel < 3.0 ? '🟡 LOW' :
399
+ bufferLevel < 5.0 ? '🟢 GOOD' :
400
+ '🔵 HIGH';
401
+ queueEl.textContent = `${scheduled.length} buffers, ${total.toFixed(2)}s scheduled | Buffer: ${bufferLevel.toFixed(2)}s ${bufferStatus}`;
402
+ }
403
+
404
+ function getBufferLevel() {
405
+ if (!ctx || !playing) return 0;
406
+ const currentTime = ctx.currentTime;
407
+ const bufferSeconds = Math.max(0, nextTime - currentTime);
408
+ return bufferSeconds;
409
  }
410
 
411
  function clearSchedule() {
 
466
 
467
  function sendUpdate() {
468
  if (!ws || ws.readyState !== 1) return;
469
+ const msg = {
470
+ type: "update",
471
+ ...currentParams(),
472
+ frontend_buffer_seconds: getBufferLevel() // Include buffer level for adaptive pacing
473
+ };
474
  ws.send(JSON.stringify(msg));
475
  log("→ update " + JSON.stringify(msg), "small");
476
  }
 
481
  autoUpdateTimer = setTimeout(sendUpdate, 150);
482
  }
483
 
484
+ function startBufferStatusReporting() {
485
+ stopBufferStatusReporting();
486
+ bufferStatusInterval = setInterval(() => {
487
+ if (ws && ws.readyState === 1 && connected && playing) {
488
+ const frontend_buffer_seconds = getBufferLevel();
489
+ // Send status when buffer is low or periodically (every 2 seconds)
490
+ const now = Date.now();
491
+ if (frontend_buffer_seconds < 4.0 || now % 2000 < 500) {
492
+ ws.send(JSON.stringify({
493
+ type: 'buffer_status',
494
+ frontend_buffer_seconds
495
+ }));
496
+ }
497
+ }
498
+ }, 500);
499
+ }
500
+
501
+ function stopBufferStatusReporting() {
502
+ if (bufferStatusInterval !== null) {
503
+ clearInterval(bufferStatusInterval);
504
+ bufferStatusInterval = null;
505
+ }
506
+ }
507
+
508
  function linkRangeNumber(range, number, cb) {
509
  const sync = (fromRange) => {
510
  if (fromRange) number.value = range.value;
 
670
  ws.send(JSON.stringify(msg));
671
  log("→ start " + JSON.stringify(msg), "ok");
672
  nextTime = ctx.currentTime + 0.12;
673
+
674
+ // Start buffer status reporting for adaptive pacing
675
+ startBufferStatusReporting();
676
  };
677
 
678
  ws.onmessage = async (ev) => {
 
727
  btnStop.disabled = true;
728
  setStatus("closed");
729
  log("connection closed", "warn");
730
+ stopBufferStatusReporting(); // Stop buffer reporting
731
  };
732
 
733
  ws.onerror = (e) => {
 
737
 
738
  function stop() {
739
  if (!connected) return;
740
+ stopBufferStatusReporting(); // Stop buffer reporting
741
  try {
742
  ws?.send(JSON.stringify({ type: "stop" }));
743
  } catch {}