import spaces
import os
import torch
import numpy as np
import gradio as gr
from PIL import Image
from fractions import Fraction
from transformers import AutoProcessor, AutoModelForCausalLM
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import cv2
import traceback
import matplotlib.pyplot as plt
import ffmpeg
from utils import load_model_without_flash_attn

# CUDA optimizations
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
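# Note: the autocast context above is entered once and left open for the lifetime of
# the process, and TF32 matmuls are only enabled on GPUs with compute capability >= 8
# (Ampere or newer).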
# Initialize models
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
image_predictor = SAM2ImagePredictor(sam2_model)

model_id = 'microsoft/Florence-2-large'
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_florence_model():
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    ).eval().to(device)

florence_model = load_model_without_flash_attn(load_florence_model)
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

def apply_color_mask(frame, mask, obj_id):
    cmap = plt.get_cmap("tab10")
    color = np.array(cmap(obj_id % 10)[:3])  # Use modulo 10 to cycle through colors

    # Ensure mask has the correct shape
    if mask.ndim == 4:
        mask = mask.squeeze()  # Remove singleton dimensions
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]  # Take the first channel if it's a single-channel 3D array

    # Resize mask to match frame dimensions
    mask = cv2.resize(mask.astype(np.float32), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)

    # Expand dimensions of mask and color for broadcasting
    mask = np.expand_dims(mask, axis=2)
    color = color.reshape(1, 1, 3)

    colored_mask = mask * color
    return frame * (1 - mask) + colored_mask * 255
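# apply_color_mask returns a float array: the input frame with a tab10 color blended in
# wherever the mask is nonzero. Callers cast the result back to uint8, e.g.
# (hypothetical values):
#   overlay = apply_color_mask(frame, mask, obj_id=1).astype(np.uint8)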
def run_florence(image, text_input):
    task_prompt = '<OPEN_VOCABULARY_DETECTION>'
    prompt = task_prompt + text_input
    # Cast inputs to the same dtype as the Florence-2 weights (float16 on GPU);
    # the original bfloat16 cast mismatched the float16 model and fails at inference.
    inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
        device, torch.float16 if device == "cuda" else torch.float32
    )
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    bboxes = parsed_answer[task_prompt]['bboxes']
    if not bboxes:
        print(f"No objects detected for prompt: '{text_input}'. Falling back to a full-image bounding box.")
        # Return a default bounding box covering the entire image
        return [0, 0, image.width, image.height]
    return bboxes[0]
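# run_florence returns a single [x1, y1, x2, y2] box in pixel coordinates: the first
# detection for the prompt, or the whole image when Florence-2 finds nothing.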
def remove_directory_contents(directory):
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))

def process_video(video_path, prompt, target_fps=30, max_dimension=640):
    try:
        # Get video info
        probe = ffmpeg.probe(video_path)
        video_info = next(s for s in probe['streams'] if s['codec_type'] == 'video')
        width = int(video_info['width'])
        height = int(video_info['height'])
        # r_frame_rate is a string like "30000/1001"; parse it instead of using eval()
        original_fps = float(Fraction(video_info['r_frame_rate']))

        # Calculate new dimensions while maintaining aspect ratio
        if width > height:
            if width > max_dimension:
                new_width = max_dimension
                new_height = int(height * (max_dimension / width))
            else:
                new_width = width
                new_height = height
        else:
            if height > max_dimension:
                new_height = max_dimension
                new_width = int(width * (max_dimension / height))
            else:
                new_width = width
                new_height = height

        # Keep dimensions even so the yuv420p encode at the end cannot fail on odd sizes
        new_width -= new_width % 2
        new_height -= new_height % 2

        # Determine target fps
        fps = min(original_fps, target_fps)

        print(f"Original video: {width}x{height}, {original_fps} fps")
        print(f"Processing at: {new_width}x{new_height}, {fps} fps")

        # Read and resize frames
        out, _ = (
            ffmpeg
            .input(video_path)
            .filter('fps', fps=fps)
            .filter('scale', width=new_width, height=new_height)
            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True)
        )
        frames = np.frombuffer(out, np.uint8).reshape([-1, new_height, new_width, 3])
        print(f"Read {len(frames)} frames")
        # Florence detection on first frame
        first_frame = Image.fromarray(frames[0])
        mask_box = run_florence(first_frame, prompt)
        print("Original mask box:", mask_box)

        # Convert mask_box to a numpy array for SAM2
        mask_box = np.array(mask_box)
        print("Mask box as array:", mask_box)

        # SAM2 segmentation on first frame
        image_predictor.set_image(first_frame)
        masks, _, _ = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=mask_box[None, :],
            multimask_output=False,
        )
        print("masks.shape", masks.shape)

        mask = masks.squeeze().astype(bool)
        print("Mask shape:", mask.shape)
        print("Frame shape:", frames[0].shape)
        # SAM2 video propagation
        temp_dir = "temp_frames"
        os.makedirs(temp_dir, exist_ok=True)
        for i, frame in enumerate(frames):
            Image.fromarray(frame).save(os.path.join(temp_dir, f"{i:04d}.jpg"))
        print(f"Saved {len(frames)} temporary frames")

        inference_state = video_predictor.init_state(video_path=temp_dir)
        _, _, _ = video_predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            mask=mask
        )

        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        print('Segmenting for main vid done')
        print(f"Number of segmented frames: {len(video_segments)}")
        # Apply segmentation masks to frames
        all_segmented_frames = []
        for i, frame in enumerate(frames):
            if i in video_segments:
                for out_obj_id, mask in video_segments[i].items():
                    frame = apply_color_mask(frame, mask, out_obj_id)
                # Append once per input frame, after all object masks are blended in
                all_segmented_frames.append(frame.astype(np.uint8))
            else:
                all_segmented_frames.append(frame)
        print(f"Applied masks to {len(all_segmented_frames)} frames")

        # Clean up temporary files
        remove_directory_contents(temp_dir)
        os.rmdir(temp_dir)
| output_path = "segmented_video.mp4" | |
| process = ( | |
| ffmpeg | |
| .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{new_width}x{new_height}', r=fps) | |
| .output(output_path, pix_fmt='yuv420p') | |
| .overwrite_output() | |
| .run_async(pipe_stdin=True) | |
| ) | |
| for frame in all_segmented_frames: | |
| process.stdin.write(frame.tobytes()) | |
| process.stdin.close() | |
| process.wait() | |
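        # Raw RGB24 frames are piped into ffmpeg's stdin; ffmpeg handles the yuv420p
        # conversion and mp4 encoding with its default mp4 encoder (typically libx264).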
        if not os.path.exists(output_path):
            raise ValueError(f"Output video file was not created: {output_path}")
        print(f"Successfully created output video: {output_path}")
        return output_path
    except Exception as e:
        print(f"Error in process_video: {str(e)}")
        print(traceback.format_exc())  # Print the full stack trace for debugging
        return None

def segment_video(video_file, prompt):
    if video_file is None:
        return None
    output_video = process_video(video_file, prompt)
    return output_video

demo = gr.Interface(
    fn=segment_video,
    inputs=[
        gr.Video(label="Upload Video (keep it under 10 seconds for this demo)"),
        gr.Textbox(label="Enter a text prompt for object detection (e.g. Gymnast, Car)")
    ],
    outputs=gr.Video(label="Segmented Video"),
    title="Text-Prompted Video Object Segmentation with SAMv2",
    description="""
This demo uses [Florence-2](https://huggingface.co/microsoft/Florence-2-large) to enable text-prompted object detection for [SAM2](https://github.com/facebookresearch/segment-anything).

1. Upload a short video (under 6-7 seconds; you can clone this Space on a larger GPU for longer videos).
2. Describe the object to segment (it should be visible in the first frame).
3. Get your segmented video.
"""
)

demo.launch()