prithivMLmods committed on
Commit 2acc319 · verified · 1 Parent(s): d481cbf

Update app.py

Files changed (1)
  1. app.py +319 -331
app.py CHANGED
@@ -1,10 +1,9 @@
  import os
  import random
  import uuid
- import spaces
  import time
- import base64
- from http import HTTPStatus
  from threading import Thread

  import gradio as gr
@@ -15,369 +14,358 @@ from PIL import Image, ImageOps
  import cv2

  from transformers import (
      Qwen2_5_VLForConditionalGeneration,
      AutoModelForVision2Seq,
      AutoProcessor,
      TextIteratorStreamer,
  )
- from gradio_client import utils as client_utils
- import modelscope_studio.components.antd as antd
- import modelscope_studio.components.antdx as antdx
- import modelscope_studio.components.base as ms
- import modelscope_studio.components.pro as pro

- # --- Constants and Configuration ---
  MAX_MAX_NEW_TOKENS = 5120
  DEFAULT_MAX_NEW_TOKENS = 3072
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  # --- Model Loading ---
- # A dictionary to hold our models and processors for easy access
- models = {}
- processors = {}
- MODEL_CHOICES = [
-     "Nanonets-OCR-s",
-     "MonkeyOCR-Recognition",
-     "Thyme-RL",
-     "Typhoon-OCR-7B",
-     "SmolDocling-256M-preview"
- ]
-
- def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
-     """Helper function to load a model and processor."""
-     print(f"Loading model: {model_key}...")
-     try:
-         processor_args = {"trust_remote_code": True}
-         model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
-
-         if subfolder:
-             processor_args["subfolder"] = subfolder
-             model_args["subfolder"] = subfolder
-
-         processors[model_key] = processor_class.from_pretrained(model_id, **processor_args)
-         models[model_key] = model_class.from_pretrained(model_id, **model_args).to(device).eval()
-         print(f"Successfully loaded {model_key}.")
-     except Exception as e:
-         print(f"Error loading model {model_key}: {e}")
-         # If a model fails to load, remove it from the choices
-         if model_key in MODEL_CHOICES:
-             MODEL_CHOICES.remove(model_key)
-
- # Load all models
- load_model("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Nanonets-OCR-s")
- load_model("echo80/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration, subfolder="Recognition", model_key="MonkeyOCR-Recognition")
- load_model("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Typhoon-OCR-7B")
- load_model("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq, model_key="SmolDocling-256M-preview")
- load_model("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Thyme-RL")


  # --- Preprocessing and Helper Functions ---
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
-     """Add random padding to an image."""
      image = image.convert("RGB")
      width, height = image.size
-     pad_w = int(width * random.uniform(min_percent, max_percent))
-     pad_h = int(height * random.uniform(min_percent, max_percent))
-     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=image.getpixel((0, 0)))
      return padded_image

- def downsample_video(video_path, num_frames=10):
-     """Downsample a video into a list of PIL Image frames."""
-     if not os.path.exists(video_path): return []
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      frames = []
-     if total_frames > 0:
-         frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
-         for i in frame_indices:
-             vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-             success, image = vidcap.read()
-             if success:
-                 frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
      vidcap.release()
      return frames

- def format_history_for_model(history, selected_model):
-     """Prepares history for the multimodal model, handling text and media files."""
-     last_user_message = next((item for item in reversed(history) if item["role"] == "user"), None)
-     if not last_user_message:
-         return None, [], ""
-
-     text = ""
-     files = []
-     images = []
-
-     for content_part in last_user_message["content"]:
-         if content_part["type"] == "text":
-             text = content_part["content"]
-         elif content_part["type"] == "file":
-             files.extend(content_part["content"])
-
-     for file_path in files:
-         mime_type = client_utils.get_mimetype(file_path)
-         if mime_type.startswith("image"):
-             images.append(Image.open(file_path))
-         elif mime_type.startswith("video"):
-             images.extend(downsample_video(file_path))
-
-     # Apply model-specific preprocessing
-     if selected_model == "SmolDocling-256M-preview":
-         if "OTSL" in text or "code" in text:
-             images = [add_random_padding(img) for img in images]
-
-     return text, images, selected_model

  @spaces.GPU
- # --- Gradio Events and Application Logic ---
- class Gradio_Events:
-
-     @staticmethod
-     def submit(state_value):
-         conv_id = state_value["conversation_id"]
-         context = state_value["conversation_contexts"][conv_id]
-         history = context["history"]
-         model_name = context.get("selected_model", MODEL_CHOICES[0])
-
-         processor = processors.get(model_name)
-         model = models.get(model_name)
-
-         if not processor or not model:
-             history.append({"role": "assistant", "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
-             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-             return
-
-         text, images, _ = format_history_for_model(history, model_name)
-
-         if not text and not images:
-             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-             return
-
-         history.append({
-             "role": "assistant",
-             "content": [],
-             "key": str(uuid.uuid4()),
-             "loading": True,
-         })
-         yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-
-         try:
-             messages = [{"role": "user", "content": []}]
-             if images:
-                 messages[0]["content"].extend([{"type": "image"}] * len(images))
-             messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})
-
-             prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-             inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-             streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-             generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}
-
-             thread = Thread(target=model.generate, kwargs=generation_kwargs)
-             thread.start()
-
-             buffer = ""
-             for new_text in streamer:
-                 buffer += new_text.replace("<|im_end|>", "")
-                 history[-1]["content"] = [{"type": "text", "content": buffer}]
-                 history[-1]["loading"] = True
-                 yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-
-             history[-1]["loading"] = False
-             # Final post-processing, especially for models like SmolDocling
-             final_content = buffer.strip().replace("<end_of_utterance>", "")
-             history[-1]["content"] = [{"type": "text", "content": final_content}]
-
-             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-
-         except Exception as e:
-             print(f"Error during model generation: {e}")
-             history[-1]["loading"] = False
-             history[-1]["content"] = [{"type": "text", "content": f'<span style="color: red;">An error occurred: {e}</span>'}]
-             yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
-
-     @staticmethod
-     def add_message(input_value, state_value):
-         text = input_value["text"]
-         files = input_value["files"]
-
-         if not state_value["conversation_id"]:
-             random_id = str(uuid.uuid4())
-             state_value["conversation_id"] = random_id
-             state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
-             state_value["conversation_contexts"][random_id] = {
-                 "history": [],
-                 "selected_model": MODEL_CHOICES[0]  # Default model
-             }
-
-         conv_id = state_value["conversation_id"]
-         history = state_value["conversation_contexts"][conv_id]["history"]
-         history.append({
-             "key": str(uuid.uuid4()),
-             "role": "user",
-             "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
-         })
-
-         yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
-         for chunk in Gradio_Events.submit(state_value):
-             yield chunk
-         yield Gradio_Events.postprocess_submit(state_value)
-
-     @staticmethod
-     def preprocess_submit(clear_input=True):
-         def handler(state_value):
-             conv_id = state_value["conversation_id"]
-             history = state_value["conversation_contexts"][conv_id]["history"]
-             return {
-                 input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
-                 conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
-                 add_conversation_btn: gr.update(disabled=True),
-                 chatbot: gr.update(value=history),
-                 state: gr.update(value=state_value),
-             }
-         return handler
-
-     @staticmethod
-     def postprocess_submit(state_value):
-         conv_id = state_value["conversation_id"]
-         history = state_value["conversation_contexts"][conv_id]["history"]
-         return {
-             input_comp: gr.update(loading=False),
-             add_conversation_btn: gr.update(disabled=False),
-             chatbot: gr.update(value=history),
-             state: gr.update(value=state_value),
-         }
-
-     @staticmethod
-     def apply_prompt(e: gr.EventData):
-         # Example format: {"description": "Query text", "urls": ["path/to/image.png"]}
-         prompt_data = e._data["payload"][0]["value"]
-         return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})
-
-     @staticmethod
-     def new_chat(state_value):
-         state_value["conversation_id"] = ""
-         return gr.update(active_key=""), gr.update(value=None), gr.update(value=state_value), gr.update(value=MODEL_CHOICES[0])
-
-     @staticmethod
-     def select_conversation(state_value, e: gr.EventData):
-         active_key = e._data["payload"][0]
-         if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
-             return gr.skip()
-         state_value["conversation_id"] = active_key
-         context = state_value["conversation_contexts"][active_key]
-         return gr.update(active_key=active_key), gr.update(value=context["history"]), gr.update(value=state_value), gr.update(value=context.get("selected_model", MODEL_CHOICES[0]))
-
-     @staticmethod
-     def on_model_change(model_name, state_value):
-         if state_value["conversation_id"]:
-             state_value["conversation_contexts"][state_value["conversation_id"]]["selected_model"] = model_name
-         return state_value


- # --- UI Layout and Components ---
  css = """
- .gradio-container { padding: 0 !important; }
- main.fillable { padding: 0 !important; }
- #chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }
- #conversations_sidebar .chatbot-conversations {
-     height: 100vh; background-color: var(--ms-gr-ant-color-bg-layout); padding: 8px;
- }
- #main_chat_area { padding: 16px; height: 100%; }
  """

- # Define welcome prompts based on available examples
- welcome_prompts = [
-     {
-         "title": "Reconstruct Table",
-         "description": "Reconstruct the doc [table] as it is.",
-         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/0.png"]
-     },
-     {
-         "title": "Describe Image",
-         "description": "Describe the image!",
-         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/8.png"]
-     },
-     {
-         "title": "OCR Image",
-         "description": "OCR the image",
-         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/2.jpg"]
-     },
-     {
-         "title": "Convert to Docling",
-         "description": "Convert this page to docling",
-         "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/1.png"]
-     },
- ]

- with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
-     state = gr.State({
-         "conversation_contexts": {},
-         "conversations": [],
-         "conversation_id": "",
-     })
-
-     with ms.Application(), antdx.XProvider(), ms.AutoLoading():
-         with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
-             # Left Sidebar for Conversations
-             with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
-                 with ms.Div(elem_classes="chatbot-conversations"):
-                     with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
-                         gr.Markdown("### OCR Conversations")
-                         with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
-                             ms.Text("New Conversation")
-                             with ms.Slot("icon"): antd.Icon("PlusOutlined")
-                         with antdx.Conversations() as conversations:
-                             pass  # Handled by events
-
-             # Right Main Chat Area
-             with antd.Col(flex=1, elem_style=dict(height="100%")):
-                 with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
-                     gr.Markdown("## Multimodal OCR2")
-                     chatbot = pro.Chatbot(
-                         height="calc(100vh - 200px)",
-                         # FIX: The `prompts` key now holds a dictionary for categorization
-                         welcome_config={
-                             "prompts": {
-                                 "Examples": welcome_prompts
-                             },
-                             "title": "Start by selecting an example:"
-                         }
-                     )
-                     with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
-                         with ms.Slot("prefix"):
-                             model_selector = gr.Dropdown(
-                                 choices=MODEL_CHOICES,
-                                 value=MODEL_CHOICES[0],
-                                 label="Select Model",
-                                 container=False
-                             )
-
-     # --- Event Wiring ---
-     add_conversation_btn.click(
-         fn=Gradio_Events.new_chat,
-         inputs=[state],
-         outputs=[conversations, chatbot, state, model_selector]
-     )
-     conversations.active_change(
-         fn=Gradio_Events.select_conversation,
-         inputs=[state],
-         outputs=[conversations, chatbot, state, model_selector]
-     )
-     chatbot.welcome_prompt_select(
-         fn=Gradio_Events.apply_prompt,
-         inputs=[],
-         outputs=[input_comp]
      )
-     submit_event = input_comp.submit(
-         fn=Gradio_Events.add_message,
-         inputs=[input_comp, state],
-         outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
      )
-     model_selector.change(
-         fn=Gradio_Events.on_model_change,
-         inputs=[model_selector, state],
-         outputs=[state]
      )

  if __name__ == "__main__":
-     demo.queue().launch(show_error=True, debug=True)
 
  import os
  import random
  import uuid
+ import json
  import time
+ import asyncio
  from threading import Thread

  import gradio as gr

  import cv2

  from transformers import (
+     Qwen2VLForConditionalGeneration,
      Qwen2_5_VLForConditionalGeneration,
+     AutoModelForCausalLM,
      AutoModelForVision2Seq,
      AutoProcessor,
      TextIteratorStreamer,
  )
+ from transformers.image_utils import load_image
+
+ # These imports seem to be from a custom library.
+ # If you have 'docling_core' installed, you can uncomment them.
+ # from docling_core.types.doc import DoclingDocument, DocTagsDocument

+ import re
+ import ast
+ import html
+
+ # --- Constants ---
  MAX_MAX_NEW_TOKENS = 5120
  DEFAULT_MAX_NEW_TOKENS = 3072
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  # --- Model Loading ---
+ # Load Nanonets-OCR-s
+ MODEL_ID_M = "nanonets/Nanonets-OCR-s"
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load MonkeyOCR
+ MODEL_ID_G = "echo840/MonkeyOCR"
+ SUBFOLDER = "Recognition"
+ processor_g = AutoProcessor.from_pretrained(
+     MODEL_ID_G,
+     trust_remote_code=True,
+     subfolder=SUBFOLDER
+ )
+ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_G,
+     trust_remote_code=True,
+     subfolder=SUBFOLDER,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Typhoon-OCR-7B
+ MODEL_ID_L = "scb10x/typhoon-ocr-7b"
+ processor_l = AutoProcessor.from_pretrained(MODEL_ID_L, trust_remote_code=True)
+ model_l = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_L,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load SmolDocling-256M-preview
+ MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = AutoModelForVision2Seq.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Thyme-RL
+ MODEL_ID_N = "Kwai-Keye/Thyme-RL"
+ processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
+ model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_N,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()


  # --- Preprocessing and Helper Functions ---
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
+     """Add random padding to an image based on its size."""
      image = image.convert("RGB")
      width, height = image.size
+     pad_w_percent = random.uniform(min_percent, max_percent)
+     pad_h_percent = random.uniform(min_percent, max_percent)
+     pad_w = int(width * pad_w_percent)
+     pad_h = int(height * pad_h_percent)
+     corner_pixel = image.getpixel((0, 0))  # Top-left corner
+     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
      return padded_image

+ def normalize_values(text, target_max=500):
+     """Normalize numerical values in text to a target maximum."""
+     def normalize_list(values):
+         max_value = max(values) if values else 1
+         return [round((v / max_value) * target_max) for v in values]
+
+     def process_match(match):
+         num_list = ast.literal_eval(match.group(0))
+         normalized = normalize_list(num_list)
+         return "".join([f"<loc_{num}>" for num in normalized])
+
+     pattern = r"\[([\d\.\s,]+)\]"
+     normalized_text = re.sub(pattern, process_match, text)
+     return normalized_text
+
+ def downsample_video(video_path):
+     """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
+     # Use 10 frames for video processing
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
      vidcap.release()
      return frames

+ # A placeholder function in case docling_core is not installed
+ def format_smoldocling_output(buffer_text, images):
+     cleaned_output = buffer_text.replace("<end_of_utterance>", "").strip()
+     # Check if docling_core is available and was imported
+     if 'DocTagsDocument' in globals() and 'DoclingDocument' in globals():
+         if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+             if "<chart>" in cleaned_output:
+                 cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+                 cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+             doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+             doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+             markdown_output = doc.export_to_markdown()
+             return markdown_output
+     # Fallback if library is not available or tags are not present
+     return cleaned_output
+
+ # --- Core Generation Logic ---
+ def get_model_and_processor(model_name):
+     """Helper to select model and processor."""
+     if model_name == "Nanonets-OCR-s":
+         return processor_m, model_m
+     elif model_name == "MonkeyOCR-Recognition":
+         return processor_g, model_g
+     elif model_name == "SmolDocling-256M-preview":
+         return processor_x, model_x
+     elif model_name == "Typhoon-OCR-7B":
+         return processor_l, model_l
+     elif model_name == "Thyme-RL":
+         return processor_n, model_n
+     else:
+         return None, None
+
+ def is_video_file(filepath):
+     """Check if a file has a common video extension."""
+     if not filepath:
+         return False
+     video_extensions = ['.mp4', '.mov', '.avi', '.mkv', '.webm']
+     return any(filepath.lower().endswith(ext) for ext in video_extensions)

  @spaces.GPU
+ def generate_response(
+     media_file: str,
+     query: str,
+     model_name: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float
+ ):
+     """Unified generation function for both image and video."""
+     if media_file is None:
+         yield "Please upload an image or video file first.", "Please upload an image or video file first."
+         return
+
+     processor, model = get_model_and_processor(model_name)
+     if not processor or not model:
+         yield "Invalid model selected.", "Invalid model selected."
+         return
+
+     media_type = "video" if is_video_file(media_file) else "image"
+
+     if media_type == "video":
+         frames = downsample_video(media_file)
+         images = [frame for frame, _ in frames]
+     else:  # image
+         images = [Image.open(media_file)]
+
+     if model_name == "SmolDocling-256M-preview":
+         if "OTSL" in query or "code" in query:
+             images = [add_random_padding(img) for img in images]
+         if "OCR at text at" in query or "Identify element" in query or "formula" in query:
+             query = normalize_values(query, target_max=500)
+
+     messages = [
+         {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": query}]}
+     ]
+     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+     inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+     }
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text.replace("<|im_end|>", "")
+         yield buffer
+
+     if model_name == "SmolDocling-256M-preview":
+         formatted_output = format_smoldocling_output(buffer, images)
+         yield formatted_output
+     else:
+         yield buffer.strip()
+
+ # --- Gradio Interface ---
+
+ # --- Examples ---
+ image_examples = [
+     ["images/0.png", "Reconstruct the doc [table] as it is."],
+     ["images/8.png", "Describe the image!"],
+     ["images/2.jpg", "OCR the image"],
+     ["images/1.png", "Convert this page to docling"],
+     ["images/3.png", "Convert this page to docling"],
+     ["images/4.png", "Convert chart to OTSL."],
+     ["images/5.jpg", "Convert code to text"],
+     ["images/6.jpg", "Convert this table to OTSL."],
+     ["images/7.jpg", "Convert formula to latex."],
+ ]
+ video_examples = [
+     ["videos/1.mp4", "Explain the video in detail."],
+     ["videos/2.mp4", "Explain the video in detail."]
+ ]
+ all_examples = image_examples + video_examples


+ # --- UI Styling and Helper Functions ---
  css = """
+ body, .gradio-container { font-family: 'Inter', sans-serif; }
+ .main-container { padding: 20px; }
+ .sidebar { background-color: #F7F7F7; border-right: 1px solid #E0E0E0; padding: 15px; border-radius: 15px; }
+ .chat-window { min-height: 60vh; border: 1px solid #E0E0E0; border-radius: 15px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }
+ .input-bar { padding: 10px; border-radius: 15px; background-color: #FFFFFF; border: 1px solid #E0E0E0; margin-top: 20px;}
+ .submit-button { background-color: #007AFF !important; color: white !important; font-weight: bold !important; }
+ .media-display {text-align: center; background-color: #F0F0F0; border-radius: 10px; padding: 10px; margin-bottom: 20px;}
+ .media-display img, .media-display video {max-height: 400px; margin: auto;}
  """

+ def handle_file_upload(file):
+     if file is None:
+         return None, gr.update(visible=False)
+     if is_video_file(file.name):
+         return gr.update(value=file.name, visible=False), gr.update(value=file.name, visible=True)
+     else:
+         return gr.update(value=file.name, visible=True), gr.update(value=file.name, visible=False)

+ def clear_all():
+     return None, None, None, ""
+
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     # Hidden state to store the path to the uploaded file
+     media_file_path = gr.State(None)
+
+     with gr.Row(elem_classes="main-container"):
+         # --- Sidebar ---
+         with gr.Column(scale=1, elem_classes="sidebar"):
+             gr.Markdown("### OCR Conversations")
+             add_conv_btn = gr.Button("+ Add Conversation")
+             gr.Markdown("---")
+             gr.Markdown("#### Advanced Options")
+             with gr.Accordion("⚙️ Generation Settings", open=False):
+                 max_new_tokens = gr.Slider(
+                     label="Max New Tokens",
+                     minimum=256,
+                     maximum=MAX_MAX_NEW_TOKENS,
+                     step=64,
+                     value=DEFAULT_MAX_NEW_TOKENS,
+                 )
+                 temperature = gr.Slider(
+                     label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.6
+                 )
+                 top_p = gr.Slider(
+                     label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9
+                 )
+
+         # --- Main Content Panel ---
+         with gr.Column(scale=4):
+             gr.Markdown("# Multimodal OCR")
+
+             # --- Media Display Area ---
+             with gr.Column(elem_classes="media-display"):
+                 image_display = gr.Image(type="filepath", label="Image Preview", visible=False)
+                 video_display = gr.Video(label="Video Preview", visible=False)
+                 gr.Markdown("Upload an image or video to begin.")
+
+             # --- Examples ---
+             gr.Examples(
+                 examples=all_examples,
+                 inputs=[media_file_path, "query_input"],
+                 label="Examples (Click to run)",
+                 fn=handle_file_upload,  # Custom function to update media display
+                 outputs=[image_display, video_display]
+             )
+
+             # --- Chat/Output Window ---
+             output_display = gr.Markdown(elem_classes="chat-window", value="### Output will be shown here")
+
+             # --- Input Bar ---
+             with gr.Row(elem_classes="input-bar", vertical=False):
+                 upload_btn = gr.UploadButton("📁 Add Files", file_types=["image", "video"])
+                 model_dropdown = gr.Dropdown(
+                     choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
+                     label="Select Model",
+                     value="Nanonets-OCR-s"
+                 )
+                 query_input = gr.Textbox(
+                     placeholder="Enter your query here...",
+                     show_label=False,
+                     scale=4,
+                 )
+                 submit_btn = gr.Button("▶", elem_classes="submit-button")
+
+     # --- Event Handlers ---
+     upload_btn.upload(
+         fn=handle_file_upload,
+         inputs=[upload_btn],
+         outputs=[image_display, video_display]
      )
+
+     # When file is uploaded, also store its path in the state
+     upload_btn.upload(lambda f: f.name if f else None, upload_btn, media_file_path)
+
+     submit_btn.click(
+         fn=generate_response,
+         inputs=[media_file_path, query_input, model_dropdown, max_new_tokens, temperature, top_p],
+         outputs=[output_display]
      )
+
+     add_conv_btn.click(
+         fn=clear_all,
+         outputs=[media_file_path, image_display, video_display, output_display]
      )

+
  if __name__ == "__main__":
+     demo.queue(max_size=50).launch(share=True, show_error=True)
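
For reference, a minimal sketch of what the `normalize_values` helper added in this commit does to bracketed coordinate lists in a SmolDocling query. This snippet is not part of app.py; it re-runs the same regex/`ast` logic standalone, and the sample query string is hypothetical.

import re
import ast

def normalize_values(text, target_max=500):
    # Same logic as the helper introduced above: rescale each bracketed
    # number list to target_max and emit <loc_N> tokens.
    def normalize_list(values):
        max_value = max(values) if values else 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        nums = normalize_list(ast.literal_eval(match.group(0)))
        return "".join(f"<loc_{n}>" for n in nums)

    return re.sub(r"\[([\d\.\s,]+)\]", process_match, text)

print(normalize_values("Identify element at [10, 20, 40]"))
# -> Identify element at <loc_125><loc_250><loc_500>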