Spaces: Runtime error

Update app.py

app.py CHANGED
```diff
@@ -55,8 +55,8 @@ def clear_points(image):
     # we clean all
     return [
         image, # first_frame_path
-        [], # tracking_points
-        [], # trackings_input_label
+        gr.State([]), # tracking_points
+        gr.State([]), # trackings_input_label
         image, # points_map
         #gr.State() # stored_inference_state
     ]
```
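This hunk swaps the plain `[]` reset values for `gr.State([])`, while the next hunk leaves `preprocess_video_in` returning plain `[]`, so the two conventions now coexist in the same file. For context, a minimal sketch of the usual way a handler resets session state (assumes Gradio 4.x; the component names here are illustrative, not taken from app.py):

```python
import gradio as gr

with gr.Blocks() as demo:
    image = gr.Image()                   # annotated first frame
    tracking_points = gr.State([])       # per-session list of clicked points
    trackings_input_label = gr.State([])

    def clear_points(img):
        # The documented pattern is to return plain values for State
        # outputs; recent Gradio versions also accept gr.State([]) as an
        # update, which is what this commit relies on.
        return img, [], []

    clear_btn = gr.Button("Clear points")
    clear_btn.click(
        clear_points,
        inputs=image,
        outputs=[image, tracking_points, trackings_input_label],
    )
```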
```diff
@@ -119,8 +119,8 @@ def preprocess_video_in(video_path):
 
     return [
         first_frame, # first_frame_path
-
-
+        [], # tracking_points
+        [], # trackings_input_label
         first_frame, # input_first_frame_image
         first_frame, # points_map
         extracted_frames_output_dir, # video_frames_dir
```
```diff
@@ -130,7 +130,6 @@ def preprocess_video_in(video_path):
         gr.update(open=False) # video_in_drawer
     ]
 
-@spaces.GPU(duration=120)
 def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
 
```
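Dropping the decorator here means a simple click callback no longer reserves a ZeroGPU slot; only the heavy SAM 2 calls need one (a later hunk also drops it from `get_mask_sam_process`). A minimal sketch of what `@spaces.GPU` does on Hugging Face ZeroGPU Spaces (the `spaces` package and decorator are real; `run_inference` is a hypothetical stand-in):

```python
import spaces
import torch

@spaces.GPU(duration=120)  # borrow a ZeroGPU device for up to 120 seconds
def run_inference(x: torch.Tensor) -> torch.Tensor:
    # On ZeroGPU, CUDA is only guaranteed to exist inside decorated calls;
    # cheap UI callbacks like get_point can stay on CPU.
    return (x.to("cuda") * 2).cpu()
```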
```diff
@@ -166,13 +165,13 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
 
     return tracking_points, trackings_input_label, selected_point_map
 
-#
-
+# use bfloat16 for the entire notebook
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
-
-#
-
-
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 
 def show_mask(mask, ax, obj_id=None, random_color=False):
     if random_color:
```
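This hunk runs the bfloat16 autocast and TF32 setup at module level, i.e. at import time. On a ZeroGPU Space that is risky: `torch.cuda.get_device_properties(0)` raises when no CUDA device is visible outside `@spaces.GPU` calls, which would plausibly explain the "Runtime error" status shown above. A guarded variant (my sketch, not the committed code):

```python
import torch

# Guarded version: only touch CUDA if a device is actually visible at
# import time (on ZeroGPU it usually is not).
if torch.cuda.is_available():
    # run the whole app under a bfloat16 autocast context
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # TF32 accelerates fp32 matmuls/convolutions on Ampere and newer
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
```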
```diff
@@ -218,8 +217,7 @@ def load_model(checkpoint):
     # return [sam2_checkpoint, model_cfg]
 
 
-
-@spaces.GPU(duration=120)
+
 def get_mask_sam_process(
     stored_inference_state,
     input_first_frame_image,
```
```diff
@@ -315,7 +313,7 @@ def get_mask_sam_process(
     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=180)
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
     sam2_checkpoint, model_cfg = load_model(checkpoint)
```
```diff
@@ -415,36 +413,25 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
-        return [], [], new_working_frame, new_working_frame
-
+        return gr.State([]), gr.State([]), new_working_frame, new_working_frame
 
-@spaces.GPU(duration=120)
 def reset_propagation(first_frame_path, predictor, stored_inference_state):
 
     predictor.reset_state(stored_inference_state)
     # print(f"RESET State: {stored_inference_state} ")
-    return first_frame_path, [], [], gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
+    return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
 
 
 with gr.Blocks(css=css) as demo:
-
-
-
-
-
-
-
-
-
-    first_frame_path = None
-    tracking_points = []
-    trackings_input_label = []
-    video_frames_dir = None
-    scanned_frames = None
-    loaded_predictor = None
-    stored_inference_state = None
-    stored_frame_names = None
-    available_frames_to_check = []
+    first_frame_path = gr.State()
+    tracking_points = gr.State([])
+    trackings_input_label = gr.State([])
+    video_frames_dir = gr.State()
+    scanned_frames = gr.State()
+    loaded_predictor = gr.State()
+    stored_inference_state = gr.State()
+    stored_frame_names = gr.State()
+    available_frames_to_check = gr.State([])
     with gr.Column():
         gr.Markdown(
             """
```
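This last hunk is the core fix: the removed module-level Python variables are evaluated once at startup and shared by every visitor, and plain lists or `None` cannot be wired into event handlers as inputs or outputs. `gr.State` declared inside `gr.Blocks` gives each session its own copy. A minimal sketch of the pattern (names are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    tracking_points = gr.State([])  # per-session, unlike a module-level list

    btn = gr.Button("Add point")
    out = gr.JSON()

    def add_point(points):
        # receive the current session value, return the updated one
        points = points + [(10, 20)]  # hypothetical coordinates
        return points, points

    btn.click(add_point, inputs=tracking_points,
              outputs=[tracking_points, out])

demo.launch()
```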