Update app.py
app.py CHANGED
@@ -408,6 +408,7 @@ def get_mask_sam_process(
     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
 
+'''
 #@spaces.GPU
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
     # use bfloat16 for the entire notebook
@@ -505,6 +506,103 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
             codec='libx264'
         )
 
+    return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
+'''
+
+import json
+import numpy as np
+
+def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
+    # use bfloat16 for the entire notebook
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    #### PROPAGATION ####
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
+    # set predictor
+    inference_state = stored_inference_state
+
+    if torch.cuda.is_available():
+        inference_state["device"] = 'cuda'
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    else:
+        inference_state["device"] = 'cpu'
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
+
+    frame_names = stored_frame_names
+    video_dir = video_frames_dir
+
+    # Define a directory to save the JPEG images
+    frames_output_dir = "frames_output_images"
+    os.makedirs(frames_output_dir, exist_ok=True)
+
+    # Initialize a list to store file paths of saved images
+    jpeg_images = []
+
+    # Initialize a list to store mask area ratios
+    mask_area_ratios = []
+
+    # run propagation throughout the video and collect the results in a dict;
+    # propagate_in_video is a generator yielding (frame_idx, obj_ids, mask_logits) per frame
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    for frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False):
+        video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[0] > 0.0).cpu().numpy()}
+
+        # Calculate mask area ratio
+        mask = video_segments[frame_idx][out_obj_ids[0]]
+        mask_area = np.sum(mask)  # number of True pixels in the mask
+        total_area = mask.size  # total number of pixels in the frame
+        mask_area_ratio = float(mask_area) / total_area  # ratio of mask area to total area
+        mask_area_ratios.append(mask_area_ratio)
+
+    # Save mask area ratios as a JSON file
+    mask_area_ratios_dict = {f"frame_{frame_idx}": ratio for frame_idx, ratio in enumerate(mask_area_ratios)}
+    with open("mask_area_ratios.json", "w") as f:
+        json.dump(mask_area_ratios_dict, f, indent=4)
+
+    # render the segmentation results every few frames
+    if vis_frame_type == "check":
+        vis_frame_stride = 15
+    elif vis_frame_type == "render":
+        vis_frame_stride = 1
+
+    plt.close("all")
+    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+        plt.figure(figsize=(6, 4))
+        plt.title(f"frame {out_frame_idx}")
+        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+        # Define the output filename and save the figure as a JPEG file
+        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
+        plt.savefig(output_filename, format='jpg')
+
+        # Close the plot
+        plt.close()
+
+        # Append the file path to the list
+        jpeg_images.append(output_filename)
+
+        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
+            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
+
+    torch.cuda.empty_cache()
+    print(f"JPEG_IMAGES: {jpeg_images}")
+
+    if vis_frame_type == "check":
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), mask_area_ratios_dict
+    elif vis_frame_type == "render":
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        clip = ImageSequenceClip(jpeg_images, fps=original_fps // 6)
+        final_vid_output_path = "output_video.mp4"
+        clip.write_videofile(final_vid_output_path, codec='libx264')
+
     return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
 
 def update_ui(vis_frame_type):
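
A note on the propagation loop in the new function: in the SAM 2 video predictor API, propagate_in_video is a generator that yields one (frame_idx, obj_ids, mask_logits) tuple per frame, which is why the loop above iterates over it directly rather than unpacking a single return value. A minimal sketch of that consumption pattern, assuming predictor and inference_state are set up as elsewhere in app.py:

import torch

# Sketch of the upstream SAM 2 loop. `predictor` is assumed to be a
# build_sam2_video_predictor(...) instance and `inference_state` the state
# returned by predictor.init_state(...), as elsewhere in app.py.
video_segments = {}  # frame_idx -> {obj_id: boolean mask}
with torch.inference_mode():
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }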
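
On the area computation inside that loop: the per-object mask carries a leading channel dimension (shape roughly (1, H, W)), so mask.shape[0] * mask.shape[1] would count 1 * H elements rather than H * W; mask.size counts every element regardless of layout, which is what the ratio needs. A small self-contained check with a hypothetical mask:

import numpy as np

# Hypothetical mask in the (1, H, W) layout produced by the loop above.
mask = np.zeros((1, 480, 640), dtype=bool)
mask[0, :100, :100] = True                # 10,000 True pixels

print(mask.shape[0] * mask.shape[1])      # 480 -- not the pixel count
print(mask.size)                          # 307200 == 1 * 480 * 640
print(float(np.sum(mask)) / mask.size)    # ~0.0326, the intended ratio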
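
A serialization detail behind the float(...) cast: np.float64 subclasses Python's float and slips through json.dump, but other NumPy scalars such as np.int64, np.float32, or np.bool_ raise a TypeError. Casting each ratio keeps the values plainly serializable; a default hook is the general-purpose alternative, sketched here:

import json
import numpy as np

# Sketch: convert any NumPy scalar to its Python equivalent at dump time,
# as a general alternative to casting each value with float().
def np_default(obj):
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Not JSON serializable: {type(obj)}")

counts = {"frame_0": np.int64(10000), "ratio_0": np.float32(0.0326)}
print(json.dumps(counts, indent=4, default=np_default))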
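
Separately, torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() enters the autocast context without a matching exit, so bfloat16 autocasting stays enabled for the rest of the process. The usual pattern scopes it with a with block; a minimal sketch, assuming a CUDA device:

import torch

# Sketch: scoped autocast. Inside the block, eligible CUDA ops run in
# bfloat16; autocast is switched off again when the block exits, unlike
# the bare __enter__() call above.
def matmul_bf16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return a @ b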
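
On the render branch: ImageSequenceClip accepts a list of image file paths (as above) or in-memory arrays, and fps=original_fps // 6 plays the stride-1 frames back at a sixth of the source rate. A self-contained sketch, assuming moviepy 1.x (moviepy.editor) and ffmpeg are available:

import numpy as np
from moviepy.editor import ImageSequenceClip

# Sketch: four synthetic frames standing in for the saved frame_*.jpg
# files, encoded with the same libx264 codec as above.
frames = [np.full((240, 320, 3), shade, dtype=np.uint8) for shade in (0, 85, 170, 255)]
clip = ImageSequenceClip(frames, fps=4)
clip.write_videofile("demo_output.mp4", codec="libx264")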
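
One caveat in the new function: the "check" branch now returns six values (with mask_area_ratios_dict appended) while the "render" branch still returns five, and a Gradio event handler must return one value per component listed in its outputs. A hedged sketch of one way to even out the arity; the helper and its arguments are hypothetical, not names from app.py:

import gradio as gr

# Hypothetical helper: pad the shorter "render" result so both branches
# return the same number of values; gr.update() with no arguments leaves
# the extra output component unchanged.
def render_result(final_vid_output_path, working_frame, available_frames_to_check):
    return (gr.update(value=None), gr.update(value=final_vid_output_path),
            working_frame, available_frames_to_check, gr.update(visible=True),
            gr.update())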