prithivMLmods committed
Commit 7b33c74 · verified · 1 Parent(s): 0389cd8

Update app.py

Files changed (1)
  1. app.py +159 -168
app.py CHANGED
@@ -1,184 +1,175 @@
  import gradio as gr
  import torch
- from transformers import AutoModel, AutoTokenizer
- import spaces
- import os
- import tempfile
- from PIL import Image, ImageDraw
- import re  # Import the regular expression library
-
- # --- 1. Load Model and Tokenizer (Done only once at startup) ---
- print("Loading model and tokenizer...")
- model_name = "strangervisionhf/deepseek-ocr-transformers-v5"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
- # --- FIX 1: Resolve Tokenizer Warnings ---
- # Explicitly set the pad_token_id to the eos_token_id. This is a common setup for
- # models that are used for open-ended text generation. It resolves the warning.
- tokenizer.pad_token_id = tokenizer.eos_token_id
-
- # Load the model to CPU first; it will be moved to GPU during processing
- model = AutoModel.from_pretrained(
-     model_name,
-     torch_dtype=torch.bfloat16,  # Use bfloat16 for performance and compatibility
-     trust_remote_code=True,
-     use_safetensors=True,
  )

- # --- FIX 2: Prevent AttributeError ---
- # The model's code is incompatible with the newer 'DynamicCache' in transformers.
- # Disabling the cache prevents the error-causing code path from being executed.
- # This may slightly slow down inference but ensures stability.
- model.config.use_cache = False
-
- model = model.eval()
- print("✅ Model loaded successfully.")
-
- # --- Helper function to find pre-generated result images ---
- def find_result_image(path):
-     for filename in os.listdir(path):
-         if "grounding" in filename or "result" in filename:
-             try:
-                 image_path = os.path.join(path, filename)
-                 return Image.open(image_path)
-             except Exception as e:
-                 print(f"Error opening result image {filename}: {e}")
-     return None
-
- # --- 2. Main Processing Function (UPDATED for multi-bbox drawing) ---
- @spaces.GPU
- def process_ocr_task(image, model_size, task_type, ref_text):
      """
-     Processes an image with DeepSeek-OCR for all supported tasks.
-     Now draws ALL detected bounding boxes for ANY task.
      """
      if image is None:
-         return "Please upload an image first.", None
-
-     print("🚀 Moving model to GPU...")
-     model_gpu = model.cuda()
-     print("✅ Model is on GPU.")
-
-     with tempfile.TemporaryDirectory() as output_path:
-         # Build the prompt... (same as before)
-         if task_type == "📝 Free OCR":
-             prompt = "<image>\nFree OCR."
-         elif task_type == "📄 Convert to Markdown":
-             prompt = "<image>\n<|grounding|>Convert the document to markdown."
-         elif task_type == "📈 Parse Figure":
-             prompt = "<image>\nParse the figure."
-         elif task_type == "🔍 Locate Object by Reference":
-             if not ref_text or ref_text.strip() == "":
-                 raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
-             prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
-         else:
-             prompt = "<image>\nFree OCR."
-
-         temp_image_path = os.path.join(output_path, "temp_image.png")
-         image.save(temp_image_path)
-
-         # Configure model size... (same as before)
-         size_configs = {
-             "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
-             "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
-             "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-             "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
-             "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
-         }
-         config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
-
-         print(f"🏃 Running inference with prompt: {prompt}")
-         text_result = model_gpu.infer(
-             tokenizer,
-             prompt=prompt,
-             image_file=temp_image_path,
-             output_path=output_path,
-             base_size=config["base_size"],
-             image_size=config["image_size"],
-             crop_mode=config["crop_mode"],
-             save_results=True,
-             test_compress=True,
-             eval_mode=True,
-         )

-         print(f"====\n📄 Text Result: {text_result}\n====")
-
-         # --- NEW LOGIC: Always try to find and draw all bounding boxes ---
-         result_image_pil = None
-
-         # Define the pattern to find all coordinates like [[280, 15, 696, 997]]
-         pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
-         matches = list(pattern.finditer(text_result))  # Use finditer to get all matches
-
-         if matches:
-             print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
-
-             # Create a copy of the original image to draw on
-             image_with_bboxes = image.copy()
-             draw = ImageDraw.Draw(image_with_bboxes)
-             w, h = image.size  # Get original image dimensions
-
-             for match in matches:
-                 # Extract coordinates as integers
-                 coords_norm = [int(c) for c in match.groups()]
-                 x1_norm, y1_norm, x2_norm, y2_norm = coords_norm
-
-                 # Scale the normalized coordinates (from 1000x1000 space) to the image's actual size
-                 x1 = int(x1_norm / 1000 * w)
-                 y1 = int(y1_norm / 1000 * h)
-                 x2 = int(x2_norm / 1000 * w)
-                 y2 = int(y2_norm / 1000 * h)
-
-                 # Draw the rectangle with a red outline, 3 pixels wide
-                 draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
-
-             result_image_pil = image_with_bboxes
-         else:
-             # If no coordinates are found in the text, fall back to finding a pre-generated image
-             print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
-             result_image_pil = find_result_image(output_path)
-
-         return text_result, result_image_pil
-
-
- # --- 3. Build the Gradio Interface (UPDATED) ---
- with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # 🐳 Full Demo of DeepSeek-OCR 🐳
-
-         **💡 How to use:**
-         1. **Upload an image** using the upload box.
-         2. Select a **Resolution**. `Gundam` is recommended for most documents.
-         3. Choose a **Task Type**:
-             - **📝 Free OCR**: Extracts raw text from the image.
-             - **📄 Convert to Markdown**: Converts the document into Markdown, preserving structure.
-             - **📈 Parse Figure**: Extracts structured data from charts and figures.
-             - **🔍 Locate Object by Reference**: Finds a specific object/text.
-         4. If this is helpful, please give it a like! 🙏 ❤️
-         """
      )

-     with gr.Row():
-         with gr.Column(scale=1):
-             image_input = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
-             model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Gundam (Recommended)", label="⚙️ Resolution Size")
-             task_type = gr.Dropdown(choices=["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"], value="📄 Convert to Markdown", label="🚀 Task Type")
-             ref_text_input = gr.Textbox(label="📝 Reference Text (for Locate task)", placeholder="e.g., the teacher, 20-10, a red car...", visible=False)
-             submit_btn = gr.Button("Process Image", variant="primary")

-         with gr.Column(scale=2):
-             output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
-             output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

-     # --- UI Interaction Logic ---
-     def toggle_ref_text_visibility(task):
-         return gr.Textbox(visible=True) if task == "🔍 Locate Object by Reference" else gr.Textbox(visible=False)

-     task_type.change(fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input)
-     submit_btn.click(fn=process_ocr_task, inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image])


- # --- 4. Launch the App ---
  if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
+ import os
+ import sys
+ import spaces
+ from typing import Iterable
  import gradio as gr
  import torch
+ import requests
+ from PIL import Image
+ from transformers import AutoProcessor, Florence2ForConditionalGeneration
+ from gradio.themes import Soft
+ from gradio.themes.utils import colors, fonts, sizes
+
+ colors.steel_blue = colors.Color(
+     name="steel_blue",
+     c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
+     c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
+     c800="#2E5378", c900="#264364", c950="#1E3450",
  )

+ class SteelBlueTheme(Soft):
+     def __init__(
+         self,
+         *,
+         primary_hue: colors.Color | str = colors.gray,
+         secondary_hue: colors.Color | str = colors.steel_blue,
+         neutral_hue: colors.Color | str = colors.slate,
+         text_size: sizes.Size | str = sizes.text_lg,
+         font: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
+         ),
+         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
+             text_size=text_size, font=font, font_mono=font_mono,
+         )
+         super().set(
+             background_fill_primary="*primary_50",
+             background_fill_primary_dark="*primary_900",
+             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
+             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
+             button_primary_text_color="white",
+             button_primary_text_color_hover="white",
+             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
+             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+             button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
+             button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+             slider_color="*secondary_500",
+             slider_color_dark="*secondary_600",
+             block_title_text_weight="600",
+             block_border_width="3px",
+             block_shadow="*shadow_drop_lg",
+             button_primary_shadow="*shadow_drop_lg",
+             button_large_padding="11px",
+             color_accent_soft="*primary_100",
+             block_label_background_fill="*primary_200",
+         )
+
+ steel_blue_theme = SteelBlueTheme()
+
+ css = """
+ #main-title h1 {
+     font-size: 2.3em !important;
+ }
+ #output-title h2 {
+     font-size: 2.1em !important;
+ }
+ """
+
+ MODEL_IDS = {
+     "Florence-2-base": "florence-community/Florence-2-base",
+     "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
+     "Florence-2-large": "florence-community/Florence-2-large",
+     "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
+ }
+
+ models = {}
+ processors = {}
+
+ print("Loading Florence-2 models... This may take a while.")
+ for name, repo_id in MODEL_IDS.items():
+     print(f"Loading {name}...")
+     model = Florence2ForConditionalGeneration.from_pretrained(
+         repo_id,
+         dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+     processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+     models[name] = model
+     processors[name] = processor
+     print(f"✅ Finished loading {name}.")
+
+ print("\n🎉 All models loaded successfully!")
+
+ @spaces.GPU(duration=30)
+ def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
+                             max_new_tokens: int = 1024, num_beams: int = 3):
      """
+     Runs inference using the selected Florence-2 model.
      """
      if image is None:
+         return "Please upload an image to get started."
 
+     model = models[model_name]
+     processor = processors[model_name]
+
+     inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
+
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=max_new_tokens,
+         num_beams=num_beams,
+         do_sample=False
      )

+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

+     image_size = image.size
+     parsed_answer = processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=image_size
+     )

+     return parsed_answer
 
+ florence_tasks = [
+     "<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
+     "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
+ ]

+ url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/venice.jpg?download=true"
+ example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
+     gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
+     gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the results.")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
+             task_prompt = gr.Dropdown(
+                 label="Select Task",
+                 choices=florence_tasks,
+                 value="<MORE_DETAILED_CAPTION>"
+             )
+             model_choice = gr.Radio(
+                 choices=list(MODEL_IDS.keys()),
+                 label="Select Model",
+                 value="Florence-2-base"
+             )
+             image_submit = gr.Button("Submit", variant="primary")
+
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(
+                     label="Max New Tokens", minimum=128, maximum=2048, step=128, value=1024
+                 )
+                 num_beams = gr.Slider(
+                     label="Number of Beams", minimum=1, maximum=10, step=1, value=3
+                 )
+
+         with gr.Column(scale=3):
+             gr.Markdown("## Output", elem_id="output-title")
+             parsed_output = gr.JSON(label="Parsed Answer")
+
+     image_submit.click(
+         fn=run_florence2_inference,
+         inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
+         outputs=[parsed_output]
+     )

  if __name__ == "__main__":
+     demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)
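
For readers who want to exercise the new Florence-2 inference path outside the Gradio UI, below is a minimal standalone sketch assembled only from the calls added in this commit (the model repo, task token, and example image URL are taken from the new app.py; like the new code, it assumes a transformers release with native Florence-2 support):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Florence2ForConditionalGeneration

# Same loading pattern as the loop in the new app.py, for a single model.
repo_id = "florence-community/Florence-2-base"
model = Florence2ForConditionalGeneration.from_pretrained(
    repo_id, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

# Example image used as the default value of the Gradio image component.
url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/venice.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

task = "<OD>"  # any entry from florence_tasks works as the prompt
inputs = processor(text=task, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
    do_sample=False,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# post_process_generation parses the raw decoded string into a task-keyed dict
# (for "<OD>", typically bounding boxes and labels), which the app displays as JSON.
result = processor.post_process_generation(generated_text, task=task, image_size=image.size)
print(result)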