# Spaces: Running on Zero
| import os | |
| import sys | |
| import spaces | |
| from typing import Iterable | |
| import gradio as gr | |
| import torch | |
| import requests | |
| from PIL import Image | |
| from transformers import AutoProcessor, Florence2ForConditionalGeneration | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
# Attach a custom "steel_blue" palette to gradio's colors namespace so it can
# be referenced the same way as the built-in hues (shades c50 through c950).
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class SteelBlueTheme(Soft):
    """Variant of gradio's ``Soft`` theme with steel-blue accents.

    Every constructor argument is keyword-only and is forwarded unchanged to
    ``Soft.__init__``.  Once the base theme is initialised, a fixed set of
    style variables (backgrounds, primary button, slider, block chrome) is
    overridden through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Base palette, text size and fonts are handled by the parent theme.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Style-variable overrides, kept in a mapping so the groups are easy
        # to scan.  The "*name" values refer to theme variables.
        overrides = {
            # Page and panel backgrounds.
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary button colours (light / hover / dark variants).
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Sliders.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Block chrome, shadows and spacing.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
# Single shared theme instance, built with the class defaults.
steel_blue_theme = SteelBlueTheme()

# Custom stylesheet: sets the font size of the <h1> inside #main-title and of
# the <h2> inside #output-title.
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
# Display name -> repository id passed to ``from_pretrained``.
MODEL_IDS = {
    "Florence-2-base": "florence-community/Florence-2-base",
    "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
    "Florence-2-large": "florence-community/Florence-2-large",
    "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
}

# Registries filled once at import time, keyed by the display names above.
models = {}
processors = {}

# Every checkpoint is loaded eagerly, so switching models at request time
# only needs a dictionary lookup.
print("Loading Florence-2 models... This may take a while.")
for name, repo_id in MODEL_IDS.items():
    print(f"Loading {name}...")
    model = Florence2ForConditionalGeneration.from_pretrained(
        repo_id, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    models[name], processors[name] = model, processor
    print(f"✅ Finished loading {name}.")
print("\n🎉 All models loaded successfully!")
@spaces.GPU  # ZeroGPU only attaches a GPU while a decorated function runs.
def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
                            max_new_tokens: int = 1024, num_beams: int = 3):
    """
    Runs inference using the selected Florence-2 model.

    Args:
        model_name: Key into the module-level ``models`` / ``processors``
            registries (one of the names in ``MODEL_IDS``).
        image: Input image; ``None`` when nothing has been uploaded.
        task_prompt: Florence-2 task token such as ``"<OD>"`` or ``"<CAPTION>"``.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width used for beam search.

    Returns:
        The result of ``processor.post_process_generation`` for the task, or
        a ``{"error": ...}`` dict when no image was supplied.

    Raises:
        KeyError: If ``model_name`` is not a loaded model.
    """
    if image is None:
        # The result is rendered by a gr.JSON component; a bare sentence is
        # not valid JSON, so the hint is wrapped in a JSON-serialisable dict.
        return {"error": "Please upload an image to get started."}

    model = models[model_name]
    processor = processors[model_name]

    # Slider widgets and API clients may deliver numbers as floats, while
    # generate() expects integer limits.
    max_new_tokens = int(max_new_tokens)
    num_beams = int(num_beams)

    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(
        model.device, torch.bfloat16
    )

    # No gradients are needed for generation; inference mode avoids the
    # autograd bookkeeping.
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )

    # Special tokens are kept in the decoded text and handed to the
    # processor's post-processing step together with the image size.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text, task=task_prompt, image_size=image.size
    )
# Task tokens offered in the task dropdown.
florence_tasks = [
    "<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
    "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>",
]

# Example image pre-loaded into the upload widget.  The download is bounded by
# a timeout and checked for HTTP errors; if it fails, the app still starts,
# just without an example (a gr.Image value of None is an empty widget).
url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/venice.jpg?download=true"
try:
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Let urllib3 undo any Content-Encoding before PIL reads the stream.
        response.raw.decode_content = True
        # convert() forces the full image to be read while the connection
        # is still open.
        example_image = Image.open(response.raw).convert("RGB")
except (requests.RequestException, OSError) as exc:
    print(f"Could not load the example image ({exc}); starting without one.")
    example_image = None
# UI layout: a controls column (image, task, model, advanced sliders) next to
# an output column that shows the parsed model answer as JSON.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the results.")

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=2):
            image_upload = gr.Image(
                type="pil",
                label="Upload Image",
                value=example_image,
                height=290,
            )
            task_prompt = gr.Dropdown(
                label="Select Task",
                choices=florence_tasks,
                value="<MORE_DETAILED_CAPTION>",
            )
            model_choice = gr.Radio(
                choices=list(MODEL_IDS),
                label="Select Model",
                value="Florence-2-base",
            )
            image_submit = gr.Button("Submit", variant="primary")

            # Generation parameters, collapsed by default.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=128,
                    maximum=2048,
                    step=128,
                    value=1024,
                )
                num_beams = gr.Slider(
                    label="Number of Beams",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=3,
                )

        # Right column: results.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            parsed_output = gr.JSON(label="Parsed Answer")

    # The input order must match run_florence2_inference's parameter order.
    image_submit.click(
        fn=run_florence2_inference,
        inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
        outputs=[parsed_output],
    )
# Script entry point: enable the request queue, then start the server.
if __name__ == "__main__":
    demo.queue().launch(
        debug=True,
        mcp_server=True,
        ssr_mode=False,
        show_error=True,
    )