Spaces:
Runtime error
Runtime error
| #@title Get bounding boxes for the subject | |
| from transformers import pipeline | |
| from moviepy.editor import VideoFileClip | |
| from PIL import Image | |
| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import tqdm | |
| import pickle | |
| import torch | |
# Zero-shot object detector used to localise each subject noun in the frames.
checkpoint = "google/owlvit-large-patch14"
# NOTE(review): `cache_dir` is passed here as a top-level pipeline kwarg;
# confirm it is honoured (it may need to go through `model_kwargs`).
detector = pipeline(
    model=checkpoint,
    task="zero-shot-object-detection",
    cache_dir="/coc/pskynet4/yashjain/",
    # Use the first GPU when one is visible, otherwise fall back to CPU
    # instead of failing at start-up on CUDA-less hosts.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)
| # from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
| # processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble") | |
| # model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble") | |
| # def owl_inference(image, text): | |
| # inputs = processor(text=text, images=image, return_tensors="pt") | |
| # outputs = model(**inputs) | |
| # target_sizes = torch.Tensor([image.size[::-1]]) | |
| # results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes) | |
| # return results[0]['boxes'] | |
def find_surrounding_masks(mask_presence):
    """Locate runs of missing masks and the indices that bracket them.

    Args:
        mask_presence: Sequence of truthy/falsy flags, one per frame.

    Returns:
        A list of ``(start, end)`` tuples. ``start`` is the index just before
        the first missing entry of a run and ``end`` is the index of the next
        present entry. A run that reaches the end of the sequence gets
        ``end == len(mask_presence)``.

    Note:
        A run is never opened at index 0. A single missing entry at index 0 is
        therefore not reported, and a longer leading run is reported with
        ``start == 0`` even though that entry is itself missing.
    """
    gaps = []
    anchor = None  # index preceding the gap currently being tracked, if any
    for idx, has_mask in enumerate(mask_presence):
        if anchor is None:
            # Not inside a gap: open one on the first missing entry past index 0.
            if not has_mask and idx > 0:
                anchor = idx - 1
            continue
        if has_mask:
            # Inside a gap and a mask reappeared: close the gap here.
            gaps.append((anchor, idx))
            anchor = None

    # A gap still open after the loop runs to the end of the sequence.
    if anchor is not None:
        gaps.append((anchor, len(mask_presence)))
    return gaps
def copy_edge_masks(mask_list, mask_presence):
    """Fill a trailing run of missing masks with the last present mask.

    Only a gap at the *end* of the sequence is handled; leading gaps are left
    untouched. ``mask_list`` is modified in place and ``mask_presence`` is not
    updated. Every filled slot refers to the same object as the source mask
    (no copy is made).

    Args:
        mask_list: List of masks, one per frame.
        mask_presence: Sequence of flags, truthy where the mask is real.

    Returns:
        None.
    """
    # Nothing to do for an empty sequence (indexing [-1] would raise) or when
    # the final frame already has a mask.
    if len(mask_presence) == 0 or mask_presence[-1]:
        return

    # Walk backwards to the last present mask and replicate it to the end.
    for i in reversed(range(len(mask_presence))):
        if mask_presence[i]:
            mask_list[i + 1:] = [mask_list[i]] * (len(mask_presence) - i - 1)
            break
def interpolate_masks(mask_list, mask_presence):
    """Fill gaps in ``mask_list`` by linear blending between bracketing masks.

    The gaps come from ``find_surrounding_masks(mask_presence)``. Each missing
    entry is replaced by an element-wise linear blend of the masks at the two
    bracketing indices, rounded and cast to ``int``. ``mask_list`` is modified
    in place and also returned.

    For a gap that runs to the end of the sequence, the right-hand index is
    clamped to the last position, so the blend target is whatever is stored
    there and that last entry itself is left as is.

    Args:
        mask_list: List of array-like masks supporting ``-``, ``/``, ``*``,
            ``.round()`` and ``.astype()``.
        mask_presence: Flags of the same length as ``mask_list``.

    Returns:
        The same ``mask_list`` object, with gap entries replaced.
    """
    assert len(mask_list) == len(mask_presence), "Mask list and presence list must have the same length."

    # Slice assignments below replace N items with N items, so the length
    # (and therefore the last valid index) stays constant.
    last_index = len(mask_list) - 1

    for left, right in find_surrounding_masks(mask_presence):
        right = min(right, last_index)
        missing = right - left - 1  # number of entries strictly between the two
        first_mask = mask_list[left]
        delta = (mask_list[right] - first_mask) / (missing + 1)
        mask_list[left + 1:right] = [
            (first_mask + delta * k).round().astype(int)
            for k in range(1, missing + 1)
        ]

    return mask_list
def get_bounding_boxes(clip_path, subject, max_frames=25):
    """Build per-frame box masks for ``subject`` over the start of a clip.

    Each of the first ``max_frames`` frames is run through the module-level
    ``detector`` with ``subject`` as the only candidate label. The first
    returned prediction's box is rasterised into a binary mask. Frames with no
    usable prediction get an all-zero mask, and the whole sequence is then
    passed through ``interpolate_masks``.

    Args:
        clip_path: Path to the video file.
        subject: Text label to detect.
        max_frames: Maximum number of frames to process (default 25, matching
            the previous hard-coded limit).

    Returns:
        ``(masks, num_bb)``: the list of per-frame masks, each of shape
        ``(height, width)``, and the number of frames in which a box was found.
    """
    clip = VideoFileClip(clip_path)
    all_bboxes = []
    bbox_present = []
    num_bb = 0
    try:
        for fidx, frame in enumerate(clip.iter_frames()):
            if fidx >= max_frames:
                break
            frame = Image.fromarray(frame)
            # PIL reports (width, height); NumPy wants (rows, cols).
            canvas = np.zeros(frame.size[::-1])
            predictions = detector(
                frame,
                candidate_labels=[subject],
            )
            try:
                box = predictions[0]["box"]
                xmin, ymin, xmax, ymax = (
                    box["xmin"], box["ymin"], box["xmax"], box["ymax"]
                )
            except (IndexError, KeyError, TypeError):
                # No usable detection: keep an empty canvas, interpolated later.
                all_bboxes.append(canvas)
                bbox_present.append(False)
                continue

            # Paint the box region onto the canvas.
            canvas[ymin:ymax, xmin:xmax] = 1
            all_bboxes.append(canvas)
            bbox_present.append(True)
            num_bb += 1
    finally:
        # Release the underlying video reader even if detection fails.
        clip.close()

    # Design decision: missing frames are filled by interpolation.
    interpolated_masks = interpolate_masks(all_bboxes, bbox_present)
    return interpolated_masks, num_bb
| import json | |
# Driver: attach per-noun mask sequences to every annotation record and
# pickle the result, checkpointing periodically.
BASE_DIR = '/scr/clips_downsampled_5fps_downsized_224x224'
ANNOTATIONS_PATH = '/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/ssv2_label_ssv2_template/ssv2_ret_label_val_small_filtered.json'
OUTPUT_PATH = '/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl'
CHECKPOINT_EVERY = 100  # write a partial pickle after this many records

with open(ANNOTATIONS_PATH, 'r') as f:
    annotations = json.load(f)

records_with_masks = []
ridx = 0
# Iterating the list directly lets tqdm show the total.
for record in tqdm.tqdm(annotations):
    video_id = record['video']
    print(f"{record['caption']} - {record['nouns']}")

    # The annotation names .webm files; the clips on disk are .mp4.
    mp4_name = video_id.replace('webm', 'mp4')
    new_record = record.copy()
    new_record['video'] = mp4_name

    all_masks = []
    all_num_bb = []
    for subject in record['nouns']:
        masks, num_bb = get_bounding_boxes(
            clip_path=os.path.join(BASE_DIR, mp4_name),
            subject=subject,
        )
        all_masks.append(masks)
        all_num_bb.append(num_bb)

    print(f"{record['video']} , subj - {record['nouns']}, bb - {all_num_bb}")

    new_record['masks'] = all_masks
    records_with_masks.append(new_record)
    ridx += 1

    # Periodic checkpoint so a crash does not lose all processed records.
    if ridx % CHECKPOINT_EVERY == 0:
        with open(OUTPUT_PATH, 'wb') as f:
            pickle.dump(records_with_masks, f)

# Final write with every processed record.
with open(OUTPUT_PATH, 'wb') as f:
    pickle.dump(records_with_masks, f)