Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,9 +1,9 @@
 import os
 import random
 import uuid
-import json
 import time
-import
 from threading import Thread
 
 import gradio as gr
@@ -14,336 +14,373 @@ from PIL import Image, ImageOps
 import cv2
 
 from transformers import (
-    Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
-    AutoModelForCausalLM,
     AutoModelForVision2Seq,
     AutoProcessor,
     TextIteratorStreamer,
 )
-from
 
-#
-# If you have 'docling_core' installed, you can uncomment them.
-# from docling_core.types.doc import DoclingDocument, DocTagsDocument
-
-import re
-import ast
-import html
-
-# Constants for text generation
 MAX_MAX_NEW_TOKENS = 5120
 DEFAULT_MAX_NEW_TOKENS = 3072
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # --- Model Loading ---
-#
-    MODEL_ID_X,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
-
-# Thyme-RL
-MODEL_ID_N = "Kwai-Keye/Thyme-RL"
-processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
-model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_N,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
 
 
 # --- Preprocessing and Helper Functions ---
 def add_random_padding(image, min_percent=0.1, max_percent=0.10):
-    """Add random padding to an image
     image = image.convert("RGB")
     width, height = image.size
-    pad_h = int(height * pad_h_percent)
-    corner_pixel = image.getpixel((0, 0))  # Top-left corner
-    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
     return padded_image
 
-def
-    """
-
-    max_value = max(values) if values else 1
-    return [round((v / max_value) * target_max) for v in values]
-
-    def process_match(match):
-        num_list = ast.literal_eval(match.group(0))
-        normalized = normalize_list(num_list)
-        return "".join([f"<loc_{num}>" for num in normalized])
-
-    pattern = r"\[([\d\.\s,]+)\]"
-    normalized_text = re.sub(pattern, process_match, text)
-    return normalized_text
-
-def downsample_video(video_path):
-    """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
-        pil_image = Image.fromarray(image)
-        timestamp = round(i / fps, 2)
-        frames.append((pil_image, timestamp))
     vidcap.release()
     return frames
 
-    elif model_name == "Typhoon-OCR-7B":
-        return processor_l, model_l
-    elif model_name == "Thyme-RL":
-        return processor_n, model_n
-    else:
-        return None, None
-
-@spaces.GPU
-def generate_response(model_name: str, text: str, media_input, media_type: str,
-                      max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
-    """Unified generation function for both image and video."""
-    processor, model = get_model_and_processor(model_name)
-    if not processor or not model:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
-
-    if media_input is None:
-        yield f"Please upload a {media_type}.", f"Please upload a {media_type}."
-        return
-
-    if media_type == "video":
-        frames = downsample_video(media_input)
-        images = [frame for frame, _ in frames]
-    else:  # image
-        images = [media_input]
-
-    if model_name == "SmolDocling-256M-preview":
         if "OTSL" in text or "code" in text:
             images = [add_random_padding(img) for img in images]
-        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-            text = normalize_values(text, target_max=500)
-
-    messages = [
-        {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}]}
-    ]
-    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer, buffer
-
-    if model_name == "SmolDocling-256M-preview":
-        raw_output, formatted_output = format_smoldocling_output(buffer, images)
-        yield raw_output, formatted_output
-    else:
-        # For other models, the formatted output is just the cleaned buffer
-        yield buffer, buffer.strip()
-
-def generate_image_wrapper(*args):
-    yield from generate_response(*args, media_type="image")
-
-def generate_video_wrapper(*args):
-    yield from generate_response(*args, media_type="video")
-
-
-# --- Examples ---
-image_examples = [
-    ["Reconstruct the doc [table] as it is.", "images/0.png"],
-    ["Describe the image!", "images/8.png"],
-    ["OCR the image", "images/2.jpg"],
-    ["Convert this page to docling", "images/1.png"],
-    ["Convert this page to docling", "images/3.png"],
-    ["Convert chart to OTSL.", "images/4.png"],
-    ["Convert code to text", "images/5.jpg"],
-    ["Convert this table to OTSL.", "images/6.jpg"],
-    ["Convert formula to latex.", "images/7.jpg"],
-]
 
 
 css = """
-.
-}
-.submit-btn:hover {
-    background-color: #3498db !important;
-    box-shadow: 2px 2px 8px rgba(0,0,0,0.3) !important;
-}
-.canvas-output {
-    border: 2px solid #4682B4;
-    border-radius: 10px;
-    padding: 20px;
-    background-color: #f0f8ff;
 }
 """
 
-#
     )
-    with gr.Accordion("📄 Formatted Result (Result.md)", open=True):
-        formatted_output = gr.Markdown(label="Formatted Output")
-
-    model_choice = gr.Radio(
-        choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
-        label="🤖 Select Model",
-        value="Nanonets-OCR-s"
-    )
-
-    gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/discussions)")
-    gr.Markdown("> **[Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s)**: A powerful, state-of-the-art image-to-markdown OCR model that transforms documents into structured markdown with intelligent content recognition.")
-    gr.Markdown("> **[SmolDocling-256M](https://huggingface.co/ds4sd/SmolDocling-256M-preview)**: A multimodal Image-Text-to-Text model designed for efficient document conversion, retaining key features of the larger Docling model.")
-    gr.Markdown("> **[MonkeyOCR-Recognition](https://huggingface.co/echo840/MonkeyOCR)**: Adopts a Structure-Recognition-Relation (SRR) paradigm, simplifying the pipeline for document processing.")
-    gr.Markdown("> **[Typhoon-OCR-7B](https://huggingface.co/scb10x/typhoon-ocr-7b)**: A bilingual document parsing model for real-world documents in Thai and English, capable of extracting text from images and charts.")
-    gr.Markdown("> **[Thyme-RL](https://huggingface.co/Kwai-Keye/Thyme-RL)**: Thyme transcends traditional 'thinking with images' by autonomously generating and executing code for image processing and computation, enhancing performance on complex reasoning tasks.")
-    gr.Markdown("> ⚠️ **Note**: All models in this space are primarily optimized for image tasks and may not perform as well on video inference use cases.")
-
-    # --- Event Handlers ---
-    common_inputs = [model_choice, max_new_tokens, temperature, top_p, top_k, repetition_penalty]
-    common_outputs = [raw_output, formatted_output]
-
-    image_submit.click(
-        fn=generate_image_wrapper,
-        inputs=[image_query, image_upload] + common_inputs,
-        outputs=common_outputs
     )
     )
 
 if __name__ == "__main__":
-    demo.queue(
 import os
 import random
 import uuid
 import time
+import base64
+from http import HTTPStatus
 from threading import Thread
 
 import gradio as gr
 import cv2
 
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoModelForVision2Seq,
     AutoProcessor,
     TextIteratorStreamer,
 )
+from gradio_client import utils as client_utils
+import modelscope_studio.components.antd as antd
+import modelscope_studio.components.antdx as antdx
+import modelscope_studio.components.base as ms
+import modelscope_studio.components.pro as pro
 
+# --- Constants and Configuration ---
 MAX_MAX_NEW_TOKENS = 5120
 DEFAULT_MAX_NEW_TOKENS = 3072
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # --- Model Loading ---
+# A dictionary to hold our models and processors for easy access
+models = {}
+processors = {}
+MODEL_CHOICES = [
+    "Nanonets-OCR-s",
+    "MonkeyOCR-Recognition",
+    "Thyme-RL",
+    "Typhoon-OCR-7B",
+    "SmolDocling-256M-preview"
+]
+
+def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
+    """Helper function to load a model and processor."""
+    print(f"Loading model: {model_key}...")
+    try:
+        processor_args = {"trust_remote_code": True}
+        model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
+
+        if subfolder:
+            processor_args["subfolder"] = subfolder
+            model_args["subfolder"] = subfolder
+
+        processors[model_key] = processor_class.from_pretrained(model_id, **processor_args)
+        models[model_key] = model_class.from_pretrained(model_id, **model_args).to(device).eval()
+        print(f"Successfully loaded {model_key}.")
+    except Exception as e:
+        print(f"Error loading model {model_key}: {e}")
+        # If a model fails to load, remove it from the choices
+        if model_key in MODEL_CHOICES:
+            MODEL_CHOICES.remove(model_key)
+
+# Load all models
+load_model("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Nanonets-OCR-s")
+load_model("echo840/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration, subfolder="Recognition", model_key="MonkeyOCR-Recognition")
+load_model("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Typhoon-OCR-7B")
+load_model("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq, model_key="SmolDocling-256M-preview")
+load_model("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Thyme-RL")
 
 
 # --- Preprocessing and Helper Functions ---
 def add_random_padding(image, min_percent=0.1, max_percent=0.10):
+    """Add random padding to an image."""
     image = image.convert("RGB")
     width, height = image.size
+    pad_w = int(width * random.uniform(min_percent, max_percent))
+    pad_h = int(height * random.uniform(min_percent, max_percent))
+    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=image.getpixel((0, 0)))
     return padded_image
 
+def downsample_video(video_path, num_frames=10):
+    """Downsample a video into a list of PIL Image frames."""
+    if not os.path.exists(video_path): return []
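+    # Decode only num_frames evenly spaced frames instead of the full video.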
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     frames = []
+    if total_frames > 0:
+        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+        for i in frame_indices:
+            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+            success, image = vidcap.read()
+            if success:
+                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
     vidcap.release()
     return frames
 
+def format_history_for_model(history, selected_model):
+    """Prepares history for the multimodal model, handling text and media files."""
+    last_user_message = next((item for item in reversed(history) if item["role"] == "user"), None)
+    if not last_user_message:
+        return None, [], ""
+
+    text = ""
+    files = []
+    images = []
+
+    for content_part in last_user_message["content"]:
+        if content_part["type"] == "text":
+            text = content_part["content"]
+        elif content_part["type"] == "file":
+            files.extend(content_part["content"])
+
+    for file_path in files:
+        mime_type = client_utils.get_mimetype(file_path)
+        if mime_type.startswith("image"):
+            images.append(Image.open(file_path))
+        elif mime_type.startswith("video"):
+            images.extend(downsample_video(file_path))
+
+    # Apply model-specific preprocessing
+    if selected_model == "SmolDocling-256M-preview":
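+        # Pad the input images when the prompt asks for table (OTSL) or code conversion.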
         if "OTSL" in text or "code" in text:
             images = [add_random_padding(img) for img in images]
 
+    return text, images, selected_model
+
+
+# --- Gradio Events and Application Logic ---
+class Gradio_Events:
+
+    @staticmethod
+    def submit(state_value):
+        conv_id = state_value["conversation_id"]
+        context = state_value["conversation_contexts"][conv_id]
+        history = context["history"]
+        model_name = context.get("selected_model", MODEL_CHOICES[0])
+
+        processor = processors.get(model_name)
+        model = models.get(model_name)
+
+        if not processor or not model:
+            history.append({"role": "assistant", "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
+            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+            return
+
+        text, images, _ = format_history_for_model(history, model_name)
+
+        if not text and not images:
+            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+            return
+
+        history.append({
+            "role": "assistant",
+            "content": [],
+            "key": str(uuid.uuid4()),
+            "loading": True,
+        })
+        yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+        try:
+            messages = [{"role": "user", "content": []}]
+            if images:
+                messages[0]["content"].extend([{"type": "image"}] * len(images))
+            messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})
+
+            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+            inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+            streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}
+
+            thread = Thread(target=model.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            buffer = ""
+            for new_text in streamer:
+                buffer += new_text.replace("<|im_end|>", "")
+                history[-1]["content"] = [{"type": "text", "content": buffer}]
+                history[-1]["loading"] = True
+                yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+            history[-1]["loading"] = False
+            # Final post-processing, especially for models like SmolDocling
+            final_content = buffer.strip().replace("<end_of_utterance>", "")
+            history[-1]["content"] = [{"type": "text", "content": final_content}]
+
+            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+        except Exception as e:
+            print(f"Error during model generation: {e}")
+            history[-1]["loading"] = False
+            history[-1]["content"] = [{"type": "text", "content": f'<span style="color: red;">An error occurred: {e}</span>'}]
+            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+    @staticmethod
+    def add_message(input_value, state_value):
+        text = input_value["text"]
+        files = input_value["files"]
+
+        if not state_value["conversation_id"]:
+            random_id = str(uuid.uuid4())
+            state_value["conversation_id"] = random_id
+            state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
+            state_value["conversation_contexts"][random_id] = {
+                "history": [],
+                "selected_model": MODEL_CHOICES[0]  # Default model
+            }
+
+        conv_id = state_value["conversation_id"]
+        history = state_value["conversation_contexts"][conv_id]["history"]
+        history.append({
+            "key": str(uuid.uuid4()),
+            "role": "user",
+            "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
+        })
+
+        yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
+        for chunk in Gradio_Events.submit(state_value):
+            yield chunk
+        yield Gradio_Events.postprocess_submit(state_value)
+
+    @staticmethod
+    def preprocess_submit(clear_input=True):
+        def handler(state_value):
+            conv_id = state_value["conversation_id"]
+            history = state_value["conversation_contexts"][conv_id]["history"]
+            return {
+                input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
+                conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
+                add_conversation_btn: gr.update(disabled=True),
+                chatbot: gr.update(value=history),
+                state: gr.update(value=state_value),
+            }
+        return handler
+
+    @staticmethod
+    def postprocess_submit(state_value):
+        conv_id = state_value["conversation_id"]
+        history = state_value["conversation_contexts"][conv_id]["history"]
+        return {
+            input_comp: gr.update(loading=False),
+            add_conversation_btn: gr.update(disabled=False),
+            chatbot: gr.update(value=history),
+            state: gr.update(value=state_value),
+        }
+
+    @staticmethod
+    def apply_prompt(e: gr.EventData):
+        # Example format: {"description": "Query text", "urls": ["path/to/image.png"]}
+        prompt_data = e._data["payload"][0]["value"]
+        return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})
+
+    @staticmethod
+    def new_chat(state_value):
+        state_value["conversation_id"] = ""
+        return gr.update(active_key=""), gr.update(value=None), gr.update(value=state_value), gr.update(value=MODEL_CHOICES[0])
+
+    @staticmethod
+    def select_conversation(state_value, e: gr.EventData):
+        active_key = e._data["payload"][0]
+        if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
+            return gr.skip()
+        state_value["conversation_id"] = active_key
+        context = state_value["conversation_contexts"][active_key]
+        return gr.update(active_key=active_key), gr.update(value=context["history"]), gr.update(value=state_value), gr.update(value=context.get("selected_model", MODEL_CHOICES[0]))
+
+    @staticmethod
+    def on_model_change(model_name, state_value):
+        if state_value["conversation_id"]:
+            state_value["conversation_contexts"][state_value["conversation_id"]]["selected_model"] = model_name
+        return state_value
 
+
+# --- UI Layout and Components ---
 css = """
+.gradio-container { padding: 0 !important; }
+main.fillable { padding: 0 !important; }
+#chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }
+#conversations_sidebar .chatbot-conversations {
+    height: 100vh; background-color: var(--ms-gr-ant-color-bg-layout); padding: 8px;
 }
+#main_chat_area { padding: 16px; height: 100%; }
 """
 
+# Define welcome prompts based on available examples
+welcome_prompts = [
+    {
+        "title": "Reconstruct Table",
+        "description": "Reconstruct the doc [table] as it is.",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/0.png"]
+    },
+    {
+        "title": "Describe Image",
+        "description": "Describe the image!",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/8.png"]
+    },
+    {
+        "title": "OCR Image",
+        "description": "OCR the image",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/2.jpg"]
+    },
+    {
+        "title": "Convert to Docling",
+        "description": "Convert this page to docling",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/1.png"]
+    },
+    {
+        "title": "Convert Chart",
+        "description": "Convert chart to OTSL.",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/4.png"]
+    },
+    {
+        "title": "Extract Code",
+        "description": "Convert code to text",
+        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/5.jpg"]
+    },
+]
+
+with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
+    state = gr.State({
+        "conversation_contexts": {},
+        "conversations": [],
+        "conversation_id": "",
+    })
+
+    with ms.Application(), antdx.XProvider(), ms.AutoLoading():
+        with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
+            # Left Sidebar for Conversations
+            with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
+                with ms.Div(elem_classes="chatbot-conversations"):
+                    with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
+                        gr.Markdown("### OCR Conversations")
+                        with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
+                            ms.Text("New Conversation")
+                            with ms.Slot("icon"): antd.Icon("PlusOutlined")
+                        with antdx.Conversations() as conversations:
+                            pass  # Handled by events
+
+            # Right Main Chat Area
+            with antd.Col(flex=1, elem_style=dict(height="100%")):
+                with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
+                    gr.Markdown("## Multimodal OCR2")
+                    chatbot = pro.Chatbot(
+                        height="calc(100vh - 200px)",
+                        welcome_config=pro.Chatbot.WelcomeConfig(prompts=welcome_prompts, title="Start by selecting an example:")
                     )
+                    with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
+                        with ms.Slot("prefix"):
+                            model_selector = gr.Dropdown(
+                                choices=MODEL_CHOICES,
+                                value=MODEL_CHOICES[0],
+                                label="Select Model",
+                                container=False
+                            )
+
+    # --- Event Wiring ---
+    add_conversation_btn.click(
+        fn=Gradio_Events.new_chat,
+        inputs=[state],
+        outputs=[conversations, chatbot, state, model_selector]
     )
+    conversations.active_change(
+        fn=Gradio_Events.select_conversation,
+        inputs=[state],
+        outputs=[conversations, chatbot, state, model_selector]
+    )
+    chatbot.welcome_prompt_select(
+        fn=Gradio_Events.apply_prompt,
+        inputs=[],
+        outputs=[input_comp]
+    )
+    submit_event = input_comp.submit(
+        fn=Gradio_Events.add_message,
+        inputs=[input_comp, state],
+        outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
+    )
+    model_selector.change(
+        fn=Gradio_Events.on_model_change,
+        inputs=[model_selector, state],
+        outputs=[state]
     )
 
 if __name__ == "__main__":
+    demo.queue().launch(show_error=True, debug=True)