prithivMLmods committed · Commit e61207c · verified · 1 Parent(s): 25a44d8

Update app.py

Files changed (1):
  1. app.py +334 -297
app.py CHANGED
@@ -1,9 +1,9 @@
import os
import random
import uuid
- import json
import time
- import asyncio
+ import base64
+ from http import HTTPStatus
from threading import Thread

import gradio as gr
@@ -14,336 +14,373 @@ from PIL import Image, ImageOps
import cv2

from transformers import (
-     Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
-     AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
- from transformers.image_utils import load_image
+ from gradio_client import utils as client_utils
+ import modelscope_studio.components.antd as antd
+ import modelscope_studio.components.antdx as antdx
+ import modelscope_studio.components.base as ms
+ import modelscope_studio.components.pro as pro

- # These imports seem to be from a custom library.
- # If you have 'docling_core' installed, you can uncomment them.
- # from docling_core.types.doc import DoclingDocument, DocTagsDocument
-
- import re
- import ast
- import html
-
- # Constants for text generation
+ # --- Constants and Configuration ---
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Model Loading ---
- # Load Nanonets-OCR-s
- MODEL_ID_M = "nanonets/Nanonets-OCR-s"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load MonkeyOCR
- MODEL_ID_G = "echo840/MonkeyOCR"
- SUBFOLDER = "Recognition"
- processor_g = AutoProcessor.from_pretrained(
-     MODEL_ID_G,
-     trust_remote_code=True,
-     subfolder=SUBFOLDER
- )
- model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_G,
-     trust_remote_code=True,
-     subfolder=SUBFOLDER,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load Typhoon-OCR-7B
- MODEL_ID_L = "scb10x/typhoon-ocr-7b"
- processor_l = AutoProcessor.from_pretrained(MODEL_ID_L, trust_remote_code=True)
- model_l = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_L,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load SmolDocling-256M-preview
- MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
- model_x = AutoModelForVision2Seq.from_pretrained(
-     MODEL_ID_X,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Thyme-RL
- MODEL_ID_N = "Kwai-Keye/Thyme-RL"
- processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
- model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_N,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
+ # A dictionary to hold our models and processors for easy access
+ models = {}
+ processors = {}
+ MODEL_CHOICES = [
+     "Nanonets-OCR-s",
+     "MonkeyOCR-Recognition",
+     "Thyme-RL",
+     "Typhoon-OCR-7B",
+     "SmolDocling-256M-preview"
+ ]
+
+ def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
+     """Helper function to load a model and processor."""
+     print(f"Loading model: {model_key}...")
+     try:
+         processor_args = {"trust_remote_code": True}
+         model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
+
+         if subfolder:
+             processor_args["subfolder"] = subfolder
+             model_args["subfolder"] = subfolder
+
+         processors[model_key] = processor_class.from_pretrained(model_id, **processor_args)
+         models[model_key] = model_class.from_pretrained(model_id, **model_args).to(device).eval()
+         print(f"Successfully loaded {model_key}.")
+     except Exception as e:
+         print(f"Error loading model {model_key}: {e}")
+         # If a model fails to load, remove it from the choices
+         if model_key in MODEL_CHOICES:
+             MODEL_CHOICES.remove(model_key)
+
+ # Load all models
+ load_model("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Nanonets-OCR-s")
+ load_model("echo840/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration, subfolder="Recognition", model_key="MonkeyOCR-Recognition")
+ load_model("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Typhoon-OCR-7B")
+ load_model("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq, model_key="SmolDocling-256M-preview")
+ load_model("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Thyme-RL")


# --- Preprocessing and Helper Functions ---
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
-     """Add random padding to an image based on its size."""
+     """Add random padding to an image."""
    image = image.convert("RGB")
    width, height = image.size
-     pad_w_percent = random.uniform(min_percent, max_percent)
-     pad_h_percent = random.uniform(min_percent, max_percent)
-     pad_w = int(width * pad_w_percent)
-     pad_h = int(height * pad_h_percent)
-     corner_pixel = image.getpixel((0, 0))  # Top-left corner
-     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
+     pad_w = int(width * random.uniform(min_percent, max_percent))
+     pad_h = int(height * random.uniform(min_percent, max_percent))
+     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=image.getpixel((0, 0)))
    return padded_image

- def normalize_values(text, target_max=500):
-     """Normalize numerical values in text to a target maximum."""
-     def normalize_list(values):
-         max_value = max(values) if values else 1
-         return [round((v / max_value) * target_max) for v in values]
-
-     def process_match(match):
-         num_list = ast.literal_eval(match.group(0))
-         normalized = normalize_list(num_list)
-         return "".join([f"<loc_{num}>" for num in normalized])
-
-     pattern = r"\[([\d\.\s,]+)\]"
-     normalized_text = re.sub(pattern, process_match, text)
-     return normalized_text
-
- def downsample_video(video_path):
-     """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
+ def downsample_video(video_path, num_frames=10):
+     """Downsample a video into a list of PIL Image frames."""
+     if not os.path.exists(video_path): return []
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
-     # Use 10 frames for video processing
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
+     if total_frames > 0:
+         frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+         for i in frame_indices:
+             vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+             success, image = vidcap.read()
+             if success:
+                 frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    vidcap.release()
    return frames

- # A placeholder function in case docling_core is not installed
- def format_smoldocling_output(buffer_text, images):
-     cleaned_output = buffer_text.replace("<end_of_utterance>", "").strip()
-     # Check if docling_core is available and was imported
-     if 'DocTagsDocument' in globals() and 'DoclingDocument' in globals():
-         if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-             if "<chart>" in cleaned_output:
-                 cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-             cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-             doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-             doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-             markdown_output = doc.export_to_markdown()
-             return buffer_text, markdown_output
-     # Fallback if library is not available or tags are not present
-     return buffer_text, cleaned_output
-
- # --- Core Generation Logic ---
- def get_model_and_processor(model_name):
-     """Helper to select model and processor."""
-     if model_name == "Nanonets-OCR-s":
-         return processor_m, model_m
-     elif model_name == "MonkeyOCR-Recognition":
-         return processor_g, model_g
-     elif model_name == "SmolDocling-256M-preview":
-         return processor_x, model_x
-     elif model_name == "Typhoon-OCR-7B":
-         return processor_l, model_l
-     elif model_name == "Thyme-RL":
-         return processor_n, model_n
-     else:
-         return None, None
-
- @spaces.GPU
- def generate_response(model_name: str, text: str, media_input, media_type: str,
-                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
-     """Unified generation function for both image and video."""
-     processor, model = get_model_and_processor(model_name)
-     if not processor or not model:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if media_input is None:
-         yield f"Please upload a {media_type}.", f"Please upload a {media_type}."
-         return
-
-     if media_type == "video":
-         frames = downsample_video(media_input)
-         images = [frame for frame, _ in frames]
-     else:  # image
-         images = [media_input]
-
-     if model_name == "SmolDocling-256M-preview":
+ def format_history_for_model(history, selected_model):
+     """Prepares history for the multimodal model, handling text and media files."""
+     last_user_message = next((item for item in reversed(history) if item["role"] == "user"), None)
+     if not last_user_message:
+         return None, [], ""
+
+     text = ""
+     files = []
+     images = []
+
+     for content_part in last_user_message["content"]:
+         if content_part["type"] == "text":
+             text = content_part["content"]
+         elif content_part["type"] == "file":
+             files.extend(content_part["content"])
+
+     for file_path in files:
+         mime_type = client_utils.get_mimetype(file_path)
+         if mime_type.startswith("image"):
+             images.append(Image.open(file_path))
+         elif mime_type.startswith("video"):
+             images.extend(downsample_video(file_path))
+
+     # Apply model-specific preprocessing
+     if selected_model == "SmolDocling-256M-preview":
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
-         if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-             text = normalize_values(text, target_max=500)
-
-     messages = [
-         {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}]}
-     ]
-     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-     inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-     }
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text.replace("<|im_end|>", "")
-         yield buffer, buffer
-
-     if model_name == "SmolDocling-256M-preview":
-         raw_output, formatted_output = format_smoldocling_output(buffer, images)
-         yield raw_output, formatted_output
-     else:
-         # For other models, the formatted output is just the cleaned buffer
-         yield buffer, buffer.strip()
-
- def generate_image_wrapper(*args):
-     yield from generate_response(*args, media_type="image")
-
- def generate_video_wrapper(*args):
-     yield from generate_response(*args, media_type="video")
-
-
- # --- Examples ---
- image_examples = [
-     ["Reconstruct the doc [table] as it is.", "images/0.png"],
-     ["Describe the image!", "images/8.png"],
-     ["OCR the image", "images/2.jpg"],
-     ["Convert this page to docling", "images/1.png"],
-     ["Convert this page to docling", "images/3.png"],
-     ["Convert chart to OTSL.", "images/4.png"],
-     ["Convert code to text", "images/5.jpg"],
-     ["Convert this table to OTSL.", "images/6.jpg"],
-     ["Convert formula to latex.", "images/7.jpg"],
- ]

- video_examples = [
-     ["Explain the video in detail.", "videos/1.mp4"],
-     ["Explain the video in detail.", "videos/2.mp4"]
- ]
+     return text, images, selected_model
+
+
+ # --- Gradio Events and Application Logic ---
+ class Gradio_Events:
+
+     @staticmethod
+     def submit(state_value):
+         conv_id = state_value["conversation_id"]
+         context = state_value["conversation_contexts"][conv_id]
+         history = context["history"]
+         model_name = context.get("selected_model", MODEL_CHOICES[0])
+
+         processor = processors.get(model_name)
+         model = models.get(model_name)
+
+         if not processor or not model:
+             history.append({"role": "assistant", "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
+             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+             return
+
+         text, images, _ = format_history_for_model(history, model_name)
+
+         if not text and not images:
+             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+             return
+
+         history.append({
+             "role": "assistant",
+             "content": [],
+             "key": str(uuid.uuid4()),
+             "loading": True,
+         })
+         yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+         try:
+             messages = [{"role": "user", "content": []}]
+             if images:
+                 messages[0]["content"].extend([{"type": "image"}] * len(images))
+             messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})
+
+             prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+             inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+             streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+             generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}
+
+             thread = Thread(target=model.generate, kwargs=generation_kwargs)
+             thread.start()
+
+             buffer = ""
+             for new_text in streamer:
+                 buffer += new_text.replace("<|im_end|>", "")
+                 history[-1]["content"] = [{"type": "text", "content": buffer}]
+                 history[-1]["loading"] = True
+                 yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+             history[-1]["loading"] = False
+             # Final post-processing, especially for models like SmolDocling
+             final_content = buffer.strip().replace("<end_of_utterance>", "")
+             history[-1]["content"] = [{"type": "text", "content": final_content}]
+
+             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+         except Exception as e:
+             print(f"Error during model generation: {e}")
+             history[-1]["loading"] = False
+             history[-1]["content"] = [{"type": "text", "content": f'<span style="color: red;">An error occurred: {e}</span>'}]
+             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
+
+     @staticmethod
+     def add_message(input_value, state_value):
+         text = input_value["text"]
+         files = input_value["files"]
+
+         if not state_value["conversation_id"]:
+             random_id = str(uuid.uuid4())
+             state_value["conversation_id"] = random_id
+             state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
+             state_value["conversation_contexts"][random_id] = {
+                 "history": [],
+                 "selected_model": MODEL_CHOICES[0]  # Default model
+             }
+
+         conv_id = state_value["conversation_id"]
+         history = state_value["conversation_contexts"][conv_id]["history"]
+         history.append({
+             "key": str(uuid.uuid4()),
+             "role": "user",
+             "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
+         })
+
+         yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
+         for chunk in Gradio_Events.submit(state_value):
+             yield chunk
+         yield Gradio_Events.postprocess_submit(state_value)
+
+     @staticmethod
+     def preprocess_submit(clear_input=True):
+         def handler(state_value):
+             conv_id = state_value["conversation_id"]
+             history = state_value["conversation_contexts"][conv_id]["history"]
+             return {
+                 input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
+                 conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
+                 add_conversation_btn: gr.update(disabled=True),
+                 chatbot: gr.update(value=history),
+                 state: gr.update(value=state_value),
+             }
+         return handler
+
+     @staticmethod
+     def postprocess_submit(state_value):
+         conv_id = state_value["conversation_id"]
+         history = state_value["conversation_contexts"][conv_id]["history"]
+         return {
+             input_comp: gr.update(loading=False),
+             add_conversation_btn: gr.update(disabled=False),
+             chatbot: gr.update(value=history),
+             state: gr.update(value=state_value),
+         }
+
+     @staticmethod
+     def apply_prompt(e: gr.EventData):
+         # Example format: {"description": "Query text", "urls": ["path/to/image.png"]}
+         prompt_data = e._data["payload"][0]["value"]
+         return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})
+
+     @staticmethod
+     def new_chat(state_value):
+         state_value["conversation_id"] = ""
+         return gr.update(active_key=""), gr.update(value=None), gr.update(value=state_value), gr.update(value=MODEL_CHOICES[0])
+
+     @staticmethod
+     def select_conversation(state_value, e: gr.EventData):
+         active_key = e._data["payload"][0]
+         if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
+             return gr.skip()
+         state_value["conversation_id"] = active_key
+         context = state_value["conversation_contexts"][active_key]
+         return gr.update(active_key=active_key), gr.update(value=context["history"]), gr.update(value=state_value), gr.update(value=context.get("selected_model", MODEL_CHOICES[0]))
+
+     @staticmethod
+     def on_model_change(model_name, state_value):
+         if state_value["conversation_id"]:
+             state_value["conversation_contexts"][state_value["conversation_id"]]["selected_model"] = model_name
+         return state_value

- # --- UI Styling ---
+
+ # --- UI Layout and Components ---
css = """
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
-     border: none !important;
-     box-shadow: 2px 2px 5px rgba(0,0,0,0.2) !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
-     box-shadow: 2px 2px 8px rgba(0,0,0,0.3) !important;
- }
- .canvas-output {
-     border: 2px solid #4682B4;
-     border-radius: 10px;
-     padding: 20px;
-     background-color: #f0f8ff;
+ .gradio-container { padding: 0 !important; }
+ main.fillable { padding: 0 !important; }
+ #chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }
+ #conversations_sidebar .chatbot-conversations {
+     height: 100vh; background-color: var(--ms-gr-ant-color-bg-layout); padding: 8px;
}
+ #main_chat_area { padding: 16px; height: 100%; }
"""

- # --- Gradio Interface ---
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# **[Multimodal OCR2](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
-
-     with gr.Row():
-         # Left Column for Inputs and Controls
-         with gr.Column(scale=1):
-             with gr.Tabs():
-                 with gr.TabItem("🖼️ Image Inference"):
-                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     image_upload = gr.Image(type="pil", label="Upload Image", height=300)
-                     gr.Examples(
-                         examples=image_examples,
-                         inputs=[image_query, image_upload],
-                         label="Image Examples"
-                     )
-                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-
-                 with gr.TabItem("🎥 Video Inference"):
-                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     video_upload = gr.Video(label="Upload Video", height=300)
-                     gr.Examples(
-                         examples=video_examples,
-                         inputs=[video_query, video_upload],
-                         label="Video Examples"
+ # Define welcome prompts based on available examples
+ welcome_prompts = [
+     {
+         "title": "Reconstruct Table",
+         "description": "Reconstruct the doc [table] as it is.",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/0.png"]
+     },
+     {
+         "title": "Describe Image",
+         "description": "Describe the image!",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/8.png"]
+     },
+     {
+         "title": "OCR Image",
+         "description": "OCR the image",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/2.jpg"]
+     },
+     {
+         "title": "Convert to Docling",
+         "description": "Convert this page to docling",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/1.png"]
+     },
+     {
+         "title": "Convert Chart",
+         "description": "Convert chart to OTSL.",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/4.png"]
+     },
+     {
+         "title": "Extract Code",
+         "description": "Convert code to text",
+         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/5.jpg"]
+     },
+ ]
+
+ with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
+     state = gr.State({
+         "conversation_contexts": {},
+         "conversations": [],
+         "conversation_id": "",
+     })
+
+     with ms.Application(), antdx.XProvider(), ms.AutoLoading():
+         with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
+             # Left Sidebar for Conversations
+             with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
+                 with ms.Div(elem_classes="chatbot-conversations"):
+                     with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
+                         gr.Markdown("### OCR Conversations")
+                         with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
+                             ms.Text("New Conversation")
+                             with ms.Slot("icon"): antd.Icon("PlusOutlined")
+                         with antdx.Conversations() as conversations:
+                             pass  # Handled by events
+
+             # Right Main Chat Area
+             with antd.Col(flex=1, elem_style=dict(height="100%")):
+                 with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
+                     gr.Markdown("## Multimodal OCR2")
+                     chatbot = pro.Chatbot(
+                         height="calc(100vh - 200px)",
+                         welcome_config=pro.Chatbot.WelcomeConfig(prompts=welcome_prompts, title="Start by selecting an example:")
                    )
-                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-
-             with gr.Accordion("⚙️ Advanced Options", open=False):
-                 max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                 repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
-         # Right Column for Outputs and Model Info
-         with gr.Column(scale=1):
-             with gr.Column(elem_classes="canvas-output"):
-                 gr.Markdown("## Output")
-                 raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=5)
-
-                 with gr.Accordion("📄 Formatted Result (Result.md)", open=True):
-                     formatted_output = gr.Markdown(label="Formatted Output")
-
-             model_choice = gr.Radio(
-                 choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
-                 label="🤖 Select Model",
-                 value="Nanonets-OCR-s"
-             )
-
-             gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/discussions)")
-             gr.Markdown("> **[Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s)**: A powerful, state-of-the-art image-to-markdown OCR model that transforms documents into structured markdown with intelligent content recognition.")
-             gr.Markdown("> **[SmolDocling-256M](https://huggingface.co/ds4sd/SmolDocling-256M-preview)**: A multimodal Image-Text-to-Text model designed for efficient document conversion, retaining key features of the larger Docling model.")
-             gr.Markdown("> **[MonkeyOCR-Recognition](https://huggingface.co/echo840/MonkeyOCR)**: Adopts a Structure-Recognition-Relation (SRR) paradigm, simplifying the pipeline for document processing.")
-             gr.Markdown("> **[Typhoon-OCR-7B](https://huggingface.co/scb10x/typhoon-ocr-7b)**: A bilingual document parsing model for real-world documents in Thai and English, capable of extracting text from images and charts.")
-             gr.Markdown("> **[Thyme-RL](https://huggingface.co/Kwai-Keye/Thyme-RL)**: Thyme transcends traditional 'thinking with images' by autonomously generating and executing code for image processing and computation, enhancing performance on complex reasoning tasks.")
-             gr.Markdown("> ⚠️ **Note**: All models in this space are primarily optimized for image tasks and may not perform as well on video inference use cases.")
-
-     # --- Event Handlers ---
-     common_inputs = [model_choice, max_new_tokens, temperature, top_p, top_k, repetition_penalty]
-     common_outputs = [raw_output, formatted_output]
-
-     image_submit.click(
-         fn=generate_image_wrapper,
-         inputs=[image_query, image_upload] + common_inputs,
-         outputs=common_outputs
+                     with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
+                         with ms.Slot("prefix"):
+                             model_selector = gr.Dropdown(
+                                 choices=MODEL_CHOICES,
+                                 value=MODEL_CHOICES[0],
+                                 label="Select Model",
+                                 container=False
+                             )
+
+     # --- Event Wiring ---
+     add_conversation_btn.click(
+         fn=Gradio_Events.new_chat,
+         inputs=[state],
+         outputs=[conversations, chatbot, state, model_selector]
    )
-
-     video_submit.click(
-         fn=generate_video_wrapper,
-         inputs=[video_query, video_upload] + common_inputs,
-         outputs=common_outputs
+     conversations.active_change(
+         fn=Gradio_Events.select_conversation,
+         inputs=[state],
+         outputs=[conversations, chatbot, state, model_selector]
+     )
+     chatbot.welcome_prompt_select(
+         fn=Gradio_Events.apply_prompt,
+         inputs=[],
+         outputs=[input_comp]
+     )
+     submit_event = input_comp.submit(
+         fn=Gradio_Events.add_message,
+         inputs=[input_comp, state],
+         outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
+     )
+     model_selector.change(
+         fn=Gradio_Events.on_model_change,
+         inputs=[model_selector, state],
+         outputs=[state]
    )

if __name__ == "__main__":
-     demo.queue(max_size=50).launch(share=True, show_error=True)
+     demo.queue().launch(show_error=True, debug=True)
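
For reference, the streaming generation path that both the old `generate_response` and the new `Gradio_Events.submit` share can be exercised outside Gradio. The snippet below is a minimal standalone sketch, not part of the commit: it assumes the same dependencies (torch, transformers, Pillow) and the nanonets/Nanonets-OCR-s checkpoint loaded above, and the image path and prompt are illustrative placeholders.

# Minimal standalone sketch of the streaming OCR path used by app.py.
# Assumptions: Nanonets-OCR-s checkpoint, a local "sample_page.png" (illustrative).
import torch
from threading import Thread
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_id = "nanonets/Nanonets-OCR-s"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

image = Image.open("sample_page.png")  # illustrative input image
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "OCR the image"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)

# Run generate() on a background thread and consume tokens from the streamer,
# exactly as the Gradio handlers do.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 3072}).start()

buffer = ""
for new_text in streamer:
    buffer += new_text.replace("<|im_end|>", "")
print(buffer.strip())

The refactored app differs only in what it does with each streamed chunk: instead of printing, it rewrites the last assistant entry in the conversation history and yields a chatbot update on every iteration.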