Commit
·
af5a7b2
1
Parent(s):
9fb3c06
update styles endpoint
Browse files
app.py
CHANGED
|
@@ -17,6 +17,7 @@ from jam_worker import JamWorker, JamParams, JamChunk
|
|
| 17 |
import uuid, threading
|
| 18 |
|
| 19 |
import gradio as gr
|
|
|
|
| 20 |
|
| 21 |
def create_documentation_interface():
|
| 22 |
"""Create a Gradio interface for documentation and transparency"""
|
|
@@ -581,47 +582,60 @@ def jam_stop(session_id: str = Body(..., embed=True)):
|
|
| 581 |
jam_registry.pop(session_id, None)
|
| 582 |
return {"stopped": True}
|
| 583 |
|
| 584 |
-
@app.post("/jam/update")
|
| 585 |
-
def jam_update(
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
raise HTTPException(status_code=404, detail="Session not found")
|
| 593 |
-
worker.update_knobs(guidance_weight=guidance_weight, temperature=temperature, topk=topk)
|
| 594 |
-
return {"ok": True}
|
| 595 |
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
with jam_lock:
|
| 603 |
worker = jam_registry.get(session_id)
|
| 604 |
if worker is None or not worker.is_alive():
|
| 605 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 606 |
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
sw = [float(x) for x in style_weights.split(",")] if style_weights else []
|
| 615 |
-
for i, s in enumerate(extra):
|
| 616 |
-
embeds.append(worker.mrt.embed_style(s.strip()))
|
| 617 |
-
weights.append(sw[i] if i < len(sw) else 1.0)
|
| 618 |
-
|
| 619 |
-
wsum = sum(weights) or 1.0
|
| 620 |
-
weights = [w/wsum for w in weights]
|
| 621 |
-
style_vec = np.sum([w*e for w,e in zip(weights, embeds)], axis=0).astype(np.float32)
|
| 622 |
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
|
| 626 |
return {"ok": True}
|
| 627 |
|
|
|
|
| 17 |
import uuid, threading
|
| 18 |
|
| 19 |
import gradio as gr
|
| 20 |
+
from typing import Optional
|
| 21 |
|
| 22 |
def create_documentation_interface():
|
| 23 |
"""Create a Gradio interface for documentation and transparency"""
|
|
|
|
| 582 |
jam_registry.pop(session_id, None)
|
| 583 |
return {"stopped": True}
|
| 584 |
|
| 585 |
+
@app.post("/jam/update") # consolidated
|
| 586 |
+
def jam_update(
|
| 587 |
+
session_id: str = Form(...),
|
| 588 |
+
|
| 589 |
+
# knobs (all optional)
|
| 590 |
+
guidance_weight: Optional[float] = Form(None),
|
| 591 |
+
temperature: Optional[float] = Form(None),
|
| 592 |
+
topk: Optional[int] = Form(None),
|
|
|
|
|
|
|
|
|
|
| 593 |
|
| 594 |
+
# styles (all optional)
|
| 595 |
+
styles: str = Form(""),
|
| 596 |
+
style_weights: str = Form(""),
|
| 597 |
+
loop_weight: Optional[float] = Form(None), # None means "don’t change"
|
| 598 |
+
use_current_mix_as_style: bool = Form(False),
|
| 599 |
+
):
|
| 600 |
with jam_lock:
|
| 601 |
worker = jam_registry.get(session_id)
|
| 602 |
if worker is None or not worker.is_alive():
|
| 603 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 604 |
|
| 605 |
+
# --- 1) Apply knob updates (atomic under lock)
|
| 606 |
+
if any(v is not None for v in (guidance_weight, temperature, topk)):
|
| 607 |
+
worker.update_knobs(
|
| 608 |
+
guidance_weight=guidance_weight,
|
| 609 |
+
temperature=temperature,
|
| 610 |
+
topk=topk
|
| 611 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
+
# --- 2) Apply style updates only if requested
|
| 614 |
+
wants_style_update = use_current_mix_as_style or (styles.strip() != "")
|
| 615 |
+
if wants_style_update:
|
| 616 |
+
embeds, weights = [], []
|
| 617 |
+
|
| 618 |
+
# optional: include current mix as a style component
|
| 619 |
+
if use_current_mix_as_style and worker.params.combined_loop is not None:
|
| 620 |
+
lw = 1.0 if loop_weight is None else float(loop_weight)
|
| 621 |
+
embeds.append(worker.mrt.embed_style(worker.params.combined_loop))
|
| 622 |
+
weights.append(lw)
|
| 623 |
+
|
| 624 |
+
# extra text styles
|
| 625 |
+
extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
|
| 626 |
+
sw = [float(x) for x in style_weights.split(",")] if style_weights else []
|
| 627 |
+
for i, s in enumerate(extra):
|
| 628 |
+
embeds.append(worker.mrt.embed_style(s.strip()))
|
| 629 |
+
weights.append(sw[i] if i < len(sw) else 1.0)
|
| 630 |
+
|
| 631 |
+
if embeds: # only swap if we actually built something
|
| 632 |
+
wsum = sum(weights) or 1.0
|
| 633 |
+
weights = [w / wsum for w in weights]
|
| 634 |
+
style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
|
| 635 |
+
|
| 636 |
+
# install atomically
|
| 637 |
+
with worker._lock:
|
| 638 |
+
worker.params.style_vec = style_vec
|
| 639 |
|
| 640 |
return {"ok": True}
|
| 641 |
|