Spaces: Running on Zero
Commit · aecac1f · Parent(s): 87e5c0e
Use session state
app.py CHANGED
@@ -129,7 +129,7 @@ class AppState:
         self.inference_session = None
         self.model: Optional[Sam2VideoModel] = None
         self.processor: Optional[Sam2VideoProcessor] = None
-        self.device: str = "
+        self.device: str = "cuda"
         self.dtype: torch.dtype = torch.bfloat16
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
@@ -158,9 +158,6 @@ class AppState:
         return len(self.video_frames)


-GLOBAL_STATE = AppState()
-
-
 def _model_repo_from_key(key: str) -> str:
     mapping = {
         "tiny": "yonigozlan/sam2.1_hiera_tiny_hf",
@@ -171,7 +168,7 @@ def _model_repo_from_key(key: str) -> str:
     return mapping.get(key, mapping["base_plus"])


-def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
+def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
         if GLOBAL_STATE.model_repo_id == desired_repo:
@@ -189,11 +186,13 @@ def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
     GLOBAL_STATE.processor = None
     print(f"Loading model from {desired_repo}")
     device, dtype = get_device_and_dtype()
-
-
+    # free up the gpu memory
+    torch.cuda.empty_cache()
+    gc.collect()
+    print("device", device)
+    model = Sam2VideoModel.from_pretrained(desired_repo)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
-
-    model.to(device)
+    model.to(device, dtype=dtype)

     GLOBAL_STATE.model = model
     GLOBAL_STATE.processor = processor
@@ -204,11 +203,11 @@ def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
     return model, processor, device, dtype


-def ensure_session_for_current_model() -> None:
+def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
     """Ensure the model/processor match the selected repo and inference_session exists.
     If a video is already loaded, re-initialize the inference session when needed.
     """
-    model, processor, device, dtype = load_model_if_needed()
+    model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
         if GLOBAL_STATE.video_frames:
@@ -239,7 +238,7 @@ def ensure_session_for_current_model() -> None:
     GLOBAL_STATE.session_repo_id = desired_repo


-def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
+def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
     """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
     # Reset ONLY video-related fields, keep model loaded
     GLOBAL_STATE.video_frames = []
@@ -247,7 +246,7 @@ def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
     GLOBAL_STATE.masks_by_frame = {}
     GLOBAL_STATE.color_by_obj = {}

-    model, processor, device, dtype = load_model_if_needed()
+    model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)

     # Gradio Video may provide a dict with 'name' or a direct file path
     video_path: Optional[str] = None
@@ -349,9 +348,9 @@ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
     return compose_frame(state, frame_idx)


-def _ensure_color_for_obj(obj_id: int):
-    if obj_id not in GLOBAL_STATE.color_by_obj:
-        GLOBAL_STATE.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
+def _ensure_color_for_obj(state: AppState, obj_id: int):
+    if obj_id not in state.color_by_obj:
+        state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)


 def on_image_click(
@@ -384,20 +383,19 @@ def on_image_click(
     if x is None or y is None:
         raise gr.Error("Could not read click coordinates.")

-    _ensure_color_for_obj(int(obj_id))
+    _ensure_color_for_obj(state, int(obj_id))

-    processor = GLOBAL_STATE.processor
-    model = GLOBAL_STATE.model
-    inference_session = GLOBAL_STATE.inference_session
+    processor = state.processor
+    model = state.model
+    inference_session = state.inference_session

     if state.current_prompt_type == "Boxes":
         # Two-click box input
         if state.pending_box_start is None:
-            #
-
-
-
-            state.composited_frames.pop(int(frame_idx), None)
+            # For boxes, always clear old inputs (points) for this object on this frame
+            frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
+            frame_clicks[int(obj_id)] = []
+            state.composited_frames.pop(int(frame_idx), None)
             state.pending_box_start = (int(x), int(y))
             state.pending_box_start_frame_idx = int(frame_idx)
             state.pending_box_start_obj_id = int(obj_id)
@@ -420,13 +418,13 @@ def on_image_click(
             frame_idx=int(frame_idx),
             obj_ids=int(obj_id),
             input_boxes=[[[x_min, y_min, x_max, y_max]]],
-            clear_old_inputs=
+            clear_old_inputs=True,  # For boxes, always clear old inputs
         )

         frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
         obj_boxes = frame_boxes.setdefault(int(obj_id), [])
-
-
+        # For boxes, always clear old inputs
+        obj_boxes.clear()
         obj_boxes.append((x_min, y_min, x_max, y_max))
         state.composited_frames.pop(int(frame_idx), None)
     else:
@@ -454,8 +452,8 @@ def on_image_click(
         state.composited_frames.pop(int(frame_idx), None)

     # Forward on that frame
-    device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
-    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=GLOBAL_STATE.dtype):
+    device_type = "cuda" if state.device == "cuda" else "cpu"
+    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=state.dtype):
         outputs = model(
             inference_session=inference_session,
             frame_idx=int(frame_idx),
@@ -477,17 +475,17 @@ def on_image_click(
         mask_2d = mask_i.cpu().numpy().squeeze()
         masks_for_frame[int(oid)] = mask_2d

-
+    state.masks_by_frame[int(frame_idx)] = masks_for_frame
     # Invalidate cache for this frame to force recomposition
-
+    state.composited_frames.pop(int(frame_idx), None)

     # Return updated preview
-    return update_frame_display(
+    return update_frame_display(state, int(frame_idx))


-def propagate_masks(state: AppState):
-    if
-        yield "Load a video first."
+def propagate_masks(GLOBAL_STATE: gr.State):
+    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
+        yield "Load a video first.", gr.update()
         return

     processor = GLOBAL_STATE.processor
@@ -497,9 +495,11 @@ def propagate_masks(state: AppState):
     total = max(1, GLOBAL_STATE.num_frames)
     processed = 0

-
+    # Initial status; no slider change yet
+    yield f"Propagating masks: {processed}/{total}", gr.update()

     device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
+    last_frame_idx = 0
     with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=GLOBAL_STATE.dtype):
         for sam2_video_output in model.propagate_in_video_iterator(inference_session):
             H = inference_session.video_height
@@ -508,6 +508,7 @@ def propagate_masks(state: AppState):
             video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

             frame_idx = int(sam2_video_output.frame_idx)
+            last_frame_idx = frame_idx
             masks_for_frame: dict[int, np.ndarray] = {}
             obj_ids_order = list(inference_session.obj_ids)
             for i, oid in enumerate(obj_ids_order):
@@ -518,12 +519,20 @@ def propagate_masks(state: AppState):
             GLOBAL_STATE.composited_frames.pop(frame_idx, None)

             processed += 1
-
+            # Every 15th frame (or last), move slider to current frame to update preview via slider binding
+            if processed % 15 == 0 or processed == total:
+                yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+            else:
+                yield f"Propagating masks: {processed}/{total}", gr.update()

-
+    # Final status; ensure slider points to last processed frame
+    yield (
+        f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
+        gr.update(value=last_frame_idx),
+    )


-def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
+def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
     # Reset only session-related state, keep uploaded video and model
     if not GLOBAL_STATE.video_frames:
         # Nothing loaded; keep behavior
@@ -551,7 +560,7 @@ def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
             torch.cuda.empty_cache()
         except Exception:
             pass
-    ensure_session_for_current_model()
+    ensure_session_for_current_model(GLOBAL_STATE)

     # Keep current slider index if possible
     current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
@@ -561,20 +570,41 @@ def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
     slider_value = gr.update(value=current_idx)
     status = "Session reset. Prompts cleared; video preserved."
     # clear and reload model and processor
-    GLOBAL_STATE.model = None
-    GLOBAL_STATE.processor = None
-    ensure_session_for_current_model()
     return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status


 theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")

 with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
-
+    GLOBAL_STATE = gr.State(AppState())

-    gr.Markdown(
-
-
+    gr.Markdown(
+        """
+        ### SAM2 Video Tracking · powered by Hugging Face 🤗 Transformers
+        Segment and track objects across a video with SAM2 (Segment Anything 2). This demo runs the official implementation from the Hugging Face Transformers library for interactive, promptable video segmentation.
+        """
+    )
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown(
+                """
+                **Quick start**
+                - **Load a video**: Upload your own or pick an example below.
+                - **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy).
+                - **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames.
+                - **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically.
+                """
+            )
+        with gr.Column():
+            gr.Markdown(
+                """
+                **Working with results**
+                - **Preview**: Use the slider to navigate frames and see the current masks.
+                - **Propagate**: Click “Propagate across video” to track all defined objects through the entire video. The preview follows progress periodically to keep things responsive.
+                - **Export**: Render an MP4 for smooth playback using the original video FPS.
+                - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
+                """
+            )

     with gr.Row():
         with gr.Column(scale=1):
@@ -594,17 +624,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
     with gr.Row():
         obj_id_inp = gr.Number(value=1, precision=0, label="Object ID")
         label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
-        clear_old_chk = gr.Checkbox(value=
+        clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
         prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
     with gr.Column():
         propagate_btn = gr.Button("Propagate across video", variant="primary")
         propagate_status = gr.Markdown(visible=True)

     # Wire events
-    def _on_video_change(video):
-
+    def _on_video_change(GLOBAL_STATE: gr.State, video):
+        GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
         return (
-
+            GLOBAL_STATE,
             gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
             first_frame,
             status,
@@ -612,22 +642,29 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:

     video_in.change(
         _on_video_change,
-        inputs=[video_in],
-        outputs=[
+        inputs=[GLOBAL_STATE, video_in],
+        outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
         show_progress=True,
     )

     # (moved) Examples are defined above the render button
-
-
-
-
-
-
-
-
-
-
+    # Each example row must match the number of inputs (GLOBAL_STATE, video_in)
+    examples_list = [
+        [None, "./tennis.mp4"],
+        [None, "./football.mp4"],
+        [None, "./basket.mp4"],
+        [None, "./hurdles.mp4"],
+    ]
+    with gr.Row():
+        gr.Examples(
+            examples=examples_list,
+            inputs=[GLOBAL_STATE, video_in],
+            fn=_on_video_change,
+            outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
+            label="Examples",
+            cache_examples=False,
+            examples_per_page=5,
+        )
     # Examples (place before the render MP4 button) — defined after handler below

     with gr.Row():
@@ -646,23 +683,23 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
             s.processor = None
         # Stream progress text while loading (first yield shows text)
         yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
-        ensure_session_for_current_model()
+        ensure_session_for_current_model(s)
         if s is not None:
             s.is_switching_model = False
         # Final yield hides the text
        yield gr.update(visible=False, value="")

-    ckpt_radio.change(_on_ckpt_change, inputs=[
+    ckpt_radio.change(_on_ckpt_change, inputs=[GLOBAL_STATE, ckpt_radio], outputs=[ckpt_progress])

     # Also retrigger session re-init if a video already loaded
     def _rebind_session_after_ckpt(s: AppState):
-        ensure_session_for_current_model()
+        ensure_session_for_current_model(s)
         # Reset pending box corner to avoid mismatched state
         if s is not None:
             s.pending_box_start = None
         return gr.update()

-    ckpt_radio.change(_rebind_session_after_ckpt, inputs=[
+    ckpt_radio.change(_rebind_session_after_ckpt, inputs=[GLOBAL_STATE], outputs=[])

     def _sync_frame_idx(state_in: AppState, idx: int):
         if state_in is not None:
@@ -671,7 +708,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:

     frame_slider.change(
         _sync_frame_idx,
-        inputs=[
+        inputs=[GLOBAL_STATE, frame_slider],
         outputs=preview,
     )

@@ -680,26 +717,37 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
             s.current_obj_id = int(oid)
         return gr.update()

-    obj_id_inp.change(_sync_obj_id, inputs=[
+    obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])

     def _sync_label(s: AppState, lab: str):
         if s is not None and lab is not None:
             s.current_label = str(lab)
         return gr.update()

-    label_radio.change(_sync_label, inputs=[
+    label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])

     def _sync_prompt_type(s: AppState, val: str):
         if s is not None and val is not None:
             s.current_prompt_type = str(val)
             s.pending_box_start = None
-
-
-
-
+        is_points = str(val).lower() == "points"
+        # Show labels only for points; hide and disable clear_old when boxes
+        updates = [
+            gr.update(visible=is_points),
+            gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
+        ]
+        return updates
+
+    prompt_type.change(
+        _sync_prompt_type,
+        inputs=[GLOBAL_STATE, prompt_type],
+        outputs=[label_radio, clear_old_chk],
+    )

     # Image click to add a point and run forward on that frame
-    preview.select(
+    preview.select(
+        on_image_click, [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], preview
+    )

     # Playback via MP4 rendering only

@@ -747,14 +795,19 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
     except Exception as e:
         raise gr.Error(f"Failed to render video: {e}")

-    render_btn.click(_render_video, inputs=[
+    render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])

-
+    # While propagating, we stream two outputs: status text and slider value updates
+    propagate_btn.click(
+        propagate_masks,
+        inputs=[GLOBAL_STATE],
+        outputs=[propagate_status, frame_slider],
+    )

     reset_btn.click(
         reset_session,
-        inputs=
-        outputs=[
+        inputs=GLOBAL_STATE,
+        outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status],
     )
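
Note on the pattern: this commit replaces the module-level `GLOBAL_STATE = AppState()` singleton, which every visitor of the Space shared, with a `gr.State(AppState())` created inside `gr.Blocks`. Gradio deep-copies the default value for each browser session, and each event handler now receives the state as an explicit input (and returns it as an output when it must be written back). A minimal sketch of that pattern, independent of this Space (the `Counter` and `bump` names are hypothetical, not from this app):

import gradio as gr


class Counter:
    """Per-session state object (stand-in for this Space's AppState)."""

    def __init__(self):
        self.clicks = 0


def bump(state: Counter) -> tuple[Counter, str]:
    # Each browser session gets its own deep copy of the default value,
    # so one visitor's clicks never leak into another visitor's session.
    state.clicks += 1
    return state, f"Clicked {state.clicks} times"


with gr.Blocks() as demo:
    state = gr.State(Counter())  # replaces a module-level global
    out = gr.Markdown()
    btn = gr.Button("Click me")
    # The state is wired as both an input and an output of the handler.
    btn.click(bump, inputs=[state], outputs=[state, out])

demo.launch()

This is why nearly every `.change`, `.click`, and `.select` binding in the diff gains `GLOBAL_STATE` in its `inputs` list.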
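The reworked `propagate_masks` also shows Gradio's streaming behavior: a generator handler yields once per frame, and each yielded tuple is pushed to the bound outputs (`propagate_status`, `frame_slider`), with the slider moved only every 15th frame so the preview follows progress without overwhelming the UI. A minimal sketch of that shape (the `track` function and the 45-frame count are hypothetical):

import time

import gradio as gr


def track(total: int = 45):
    # A generator handler: Gradio streams every yield to the outputs.
    for i in range(1, total + 1):
        time.sleep(0.01)  # stand-in for propagating masks on one frame
        if i % 15 == 0 or i == total:
            # Periodically move the slider so the preview follows progress
            yield f"Propagating masks: {i}/{total}", gr.update(value=i)
        else:
            # Otherwise only refresh the status text
            yield f"Propagating masks: {i}/{total}", gr.update()
    yield f"Propagated masks across {total} frames.", gr.update(value=total)


with gr.Blocks() as demo:
    status = gr.Markdown()
    frame_slider = gr.Slider(0, 45, step=1, label="Frame")
    gr.Button("Propagate across video").click(
        track, inputs=[], outputs=[status, frame_slider]
    )

demo.launch()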