from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
from helpers.file_utils import create_directory, delete_directory, generate_unique_name
from helpers.segment_utils import parse_segmentation, extract_objs
import os
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VIDEO_TARGET_DIRECTORY = "tmp"
VAE_MODEL = "vae-oid.npz"
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
INTRO_TEXT = """
## PaliGemma 2 Detection/Segmentation with Supervision - Demo

<div style="display: flex; gap: 10px;">
  <a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
    <img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
  </a>
  <a href="https://huggingface.co/blog/paligemma">
    <img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
  </a>
  <a href="https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab">
  </a>
  <a href="https://arxiv.org/abs/2412.03555">
    <img src="https://img.shields.io/badge/Arvix-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
  </a>
  <a href="https://supervision.roboflow.com/">
    <img src="https://img.shields.io/badge/Supervision-6706CE?style=flat&logo=Roboflow&logoColor=white" alt="Supervision">
  </a>
</div>

PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the
[Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile model for transfer
to a wide range of vision-language tasks such as image and short video captioning, visual question answering,
text reading, object detection and object segmentation.

This space shows how to use PaliGemma 2 for object detection and segmentation with Supervision.
Upload an image and enter a prompt such as `detect person;dog;building` or `segment person;dog;building`;
the second tab runs detection prompts on videos frame by frame.
"""
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

model_id = "google/paligemma2-3b-pt-448"
# Load the weights in bfloat16 so the model dtype matches the bfloat16 inputs prepared below.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)

def parse_class_names(prompt):
    """Extract class names from a 'detect class1;class2;...' prompt; return [] if the prefix is missing."""
    if not prompt.lower().startswith('detect '):
        return []
    classes_text = prompt[7:].strip()
    return [cls.strip() for cls in classes_text.split(';') if cls.strip()]

def parse_prompt_type(prompt):
    """Determine if the prompt is for detection or segmentation."""
    if prompt.lower().startswith('detect '):
        return 'detection', prompt[7:].strip()
    elif prompt.lower().startswith('segment '):
        return 'segmentation', prompt[8:].strip()
    return None, prompt
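
# For reference, the two helpers above behave like this (illustrative inputs):
#   parse_prompt_type("detect person;dog")  -> ("detection", "person;dog")
#   parse_prompt_type("segment person;dog") -> ("segmentation", "person;dog")
#   parse_class_names("detect person;dog")  -> ["person", "dog"]
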
def paligemma_detection(input_image, input_text, max_new_tokens):
    model_inputs = processor(
        text=input_text,
        images=input_image,
        return_tensors="pt"
    ).to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)
    return result
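
# For a detection prompt, `result` is typically a string of location tokens followed by a class
# name, e.g. "<loc0256><loc0128><loc0768><loc0512> person ; ..." (y_min, x_min, y_max, x_max on a
# 0-1023 grid); segmentation prompts additionally emit <seg###> tokens per object.
# sv.Detections.from_lmm and parse_segmentation below turn this text into pixel-space annotations.
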
def annotate_image(result, resolution_wh, prompt, cv_image):
    class_names = parse_class_names(prompt)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return cv_image
    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=class_names
    )
    annotated_image = BOX_ANNOTATOR.annotate(
        scene=cv_image.copy(),
        detections=detections
    )
    annotated_image = LABEL_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = MASK_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image

@spaces.GPU  # request a ZeroGPU slot for the duration of this call
def process_image(input_image, input_text, max_new_tokens):
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    prompt_type, cleaned_prompt = parse_prompt_type(input_text)
    if prompt_type == 'detection':
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        class_names = [cls.strip() for cls in cleaned_prompt.split(';') if cls.strip()]
        detections = sv.Detections.from_lmm(
            sv.LMM.PALIGEMMA,
            result,
            resolution_wh=(input_image.width, input_image.height),
            classes=class_names
        )
        annotated_image = BOX_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
        annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
        annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
    elif prompt_type == 'segmentation':
        # Use parse_segmentation for segmentation tasks
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        input_image, annotations = parse_segmentation(input_image, result)
        annotated_image = cv_image.copy()
        for mask, label in annotations:
            if isinstance(mask, np.ndarray):  # a segmentation mask
                # Pick a colour for this label and build a solid-colour mask
                color_idx = hash(label) % len(COLORS)
                color = tuple(int(COLORS[color_idx].lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))
                colored_mask = np.zeros_like(cv_image)
                colored_mask[mask > 0] = color
                # Blend mask with image
                alpha = 0.5
                annotated_image = cv2.addWeighted(annotated_image, 1, colored_mask, alpha, 0)
                # Add label where the mask starts
                y_coords, x_coords = np.where(mask > 0)
                if len(y_coords) > 0 and len(x_coords) > 0:
                    label_y = y_coords.min()
                    label_x = x_coords.min()
                    cv2.putText(annotated_image, label, (label_x, label_y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    else:
        gr.Warning("Invalid prompt format. Please use 'detect' or 'segment' followed by class names")
        return input_image, "Invalid prompt format"
    # Convert back to RGB for display
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image, result

@spaces.GPU  # request a ZeroGPU slot for the duration of this call
def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None
    class_names = parse_class_names(input_text)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return None, None
    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    create_directory(frame_directory_path)
    video_info = sv.VideoInfo.from_video_path(input_video)
    frame_generator = sv.get_video_frames_generator(input_video)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
    results = []
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for frame in progress.tqdm(frame_generator, desc="Processing video"):
            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            model_inputs = processor(
                text=input_text,
                images=pil_frame,
                return_tensors="pt"
            ).to(torch.bfloat16).to(model.device)
            input_len = model_inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
                generation = generation[0][input_len:]
                result = processor.decode(generation, skip_special_tokens=True)
            detections = sv.Detections.from_lmm(
                sv.LMM.PALIGEMMA,
                result,
                resolution_wh=(video_info.width, video_info.height),
                classes=class_names
            )
            annotated_frame = BOX_ANNOTATOR.annotate(
                scene=frame.copy(),
                detections=detections
            )
            annotated_frame = LABEL_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            results.append(result)
            sink.write_frame(annotated_frame)
    delete_directory(frame_directory_path)
    return video_path, results

with gr.Blocks() as app:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Image Detection/Segmentation"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Increase for longer generations.")
            with gr.Column():
                annotated_image = gr.Image(type="pil", label="Annotated Image")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Submit").click(
            fn=process_image,
            inputs=[input_image, input_text, max_new_tokens],
            outputs=[annotated_image, detection_result]
        )
    with gr.Tab("Video Detection"):
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Increase for longer generations.")
            with gr.Column():
                output_video = gr.Video(label="Annotated Video")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Process Video").click(
            fn=process_video,
            inputs=[input_video, input_text, max_new_tokens],
            outputs=[output_video, detection_result]
        )
if __name__ == "__main__":
    app.launch(ssr_mode=False)
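
If you only need the detection path without the Gradio UI, the snippet below is a minimal standalone sketch distilled from the app above. It reuses the same model ID, processor calls, and `sv.Detections.from_lmm` parsing; the image path `example.jpg`, the prompt, and the output filename are placeholders to adapt.

# Minimal detection-only sketch (no Gradio), assuming a local image at "example.jpg".
import numpy as np
import torch
import supervision as sv
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

model_id = "google/paligemma2-3b-pt-448"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

image = Image.open("example.jpg").convert("RGB")   # placeholder path
prompt = "detect person;dog"                       # placeholder classes
classes = [c.strip() for c in prompt[7:].split(";") if c.strip()]

inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(device)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
result = processor.decode(output[0][input_len:], skip_special_tokens=True)

# Parse the <loc...> tokens into pixel-space boxes and draw them.
detections = sv.Detections.from_lmm(
    sv.LMM.PALIGEMMA,
    result,
    resolution_wh=(image.width, image.height),
    classes=classes,
)
annotated = sv.BoxAnnotator().annotate(scene=np.array(image), detections=detections)
annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections)
Image.fromarray(annotated).save("annotated.jpg")   # placeholder output path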