"""Gradio demo that runs several OCR / document vision-language models.

The app loads five models at import time, exposes image and video inference
tabs, streams the raw model output, and (for SmolDocling) converts DocTags
output to Markdown.
"""

import ast
import asyncio
import html
import json
import os
import random
import re
import time
import uuid
from threading import Thread
from typing import Iterable

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
import requests

from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from docling_core.types.doc import DoclingDocument, DocTagsDocument

# --- Theme and CSS Definition ---

# Register a custom palette on gradio's `colors` module so it can be used as
# a default argument in the theme below.
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",  # SteelBlue base color
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)


class SteelBlueTheme(Soft):
    """Variant of gradio's ``Soft`` theme with a steel-blue secondary hue."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual theme variables on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


steel_blue_theme = SteelBlueTheme()

css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""

# Constants for text generation
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_pretrained(model_cls, model_id, **hub_kwargs):
    """Load a processor and a half-precision, eval-mode model for ``model_id``.

    Args:
        model_cls: Transformers model class providing ``from_pretrained``.
        model_id: Hub repository id.
        **hub_kwargs: Extra keyword arguments forwarded to both
            ``from_pretrained`` calls (e.g. ``subfolder``).

    Returns:
        A ``(processor, model)`` tuple; the model is moved to ``device``.
    """
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, **hub_kwargs
    )
    model = model_cls.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        **hub_kwargs,
    ).to(device).eval()
    return processor, model


# Load Nanonets-OCR-s
MODEL_ID_M = "nanonets/Nanonets-OCR-s"
processor_m, model_m = _load_pretrained(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_M
)

# Load MonkeyOCR
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g, model_g = _load_pretrained(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_G, subfolder=SUBFOLDER
)

# Load Typhoon-OCR-7B
MODEL_ID_L = "scb10x/typhoon-ocr-7b"
processor_l, model_l = _load_pretrained(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_L
)

# Load SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x, model_x = _load_pretrained(AutoModelForVision2Seq, MODEL_ID_X)

# Thyme-RL
MODEL_ID_N = "Kwai-Keye/Thyme-RL"
processor_n, model_n = _load_pretrained(
    Qwen2_5_VLForConditionalGeneration, MODEL_ID_N
)

# UI model name -> (processor, model). Keys must match the Radio choices below.
SMOLDOCLING_NAME = "SmolDocling-256M-preview"
MODEL_REGISTRY = {
    "Nanonets-OCR-s": (processor_m, model_m),
    "MonkeyOCR-Recognition": (processor_g, model_g),
    SMOLDOCLING_NAME: (processor_x, model_x),
    "Typhoon-OCR-7B": (processor_l, model_l),
    "Thyme-RL": (processor_n, model_n),
}

# NOTE(review): tag literals below follow the SmolDocling DocTags vocabulary;
# confirm against the model card if the vocabulary changes.
_END_OF_UTTERANCE = "<end_of_utterance>"
_DOCTAG_MARKERS = ("<doctag>", "<otsl>", "<code>", "<chart>", "<formula>")
_BRACKETED_NUMBERS = re.compile(r"\[([\d\.\s,]+)\]")
_TRAILING_TAG_AFTER_LAST_LOC = re.compile(r"(<loc_500>)(?!.*<loc_500>)<[^>]+>")


# Preprocessing functions for SmolDocling-256M
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Return an RGB copy of ``image`` with a border on all four sides.

    The horizontal and vertical border sizes are drawn independently from
    ``uniform(min_percent, max_percent)`` as a fraction of the image width and
    height. With the default arguments both bounds are equal, so the padding
    is a fixed 10%. The border is filled with the top-left pixel colour.
    """
    image = image.convert("RGB")
    width, height = image.size
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    return ImageOps.expand(
        image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel
    )


def normalize_values(text, target_max=500):
    """Rewrite bracketed number lists in ``text`` as ``<loc_N>`` tags.

    Each ``[a, b, ...]`` list is scaled so that its largest value maps to
    ``target_max`` and is replaced by one ``<loc_N>`` tag per value.
    Bracketed spans that match the pattern but are not a valid list of
    numbers are left unchanged.

    Args:
        text: Prompt text possibly containing ``[n, n, ...]`` lists.
        target_max: Value the largest list element is scaled to.

    Returns:
        The text with every valid list replaced by location tags.
    """

    def normalize_list(values):
        max_value = max(values, default=0)
        if not max_value:
            # All zeros (or empty): avoid dividing by zero.
            return [0 for _ in values]
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        original = match.group(0)
        try:
            num_list = ast.literal_eval(original)
        except (ValueError, SyntaxError):
            # e.g. "[,]" or "[1..2]" match the regex but are not literals.
            return original
        if not isinstance(num_list, list) or not all(
            isinstance(v, (int, float)) and not isinstance(v, bool)
            for v in num_list
        ):
            # e.g. "[...]" evaluates to [Ellipsis].
            return original
        return "".join(f"<loc_{num}>" for num in normalize_list(num_list))

    return _BRACKETED_NUMBERS.sub(process_match, text)


def downsample_video(video_path, max_frames=10):
    """Sample up to ``max_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path of the video file to read.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of ``(PIL.Image, timestamp_seconds)`` tuples. The list is empty
        when the video reports no frames or no frame could be decoded.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0:
            return frames
        frame_indices = np.linspace(
            0, total_frames - 1, min(total_frames, max_frames), dtype=int
        )
        for i in frame_indices:
            index = int(i)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Some containers report an FPS of 0; fall back to 0.0 seconds.
            timestamp = round(index / fps, 2) if fps and fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
    finally:
        # Always release the capture handle, even if decoding raised.
        vidcap.release()
    return frames


def _prepare_smoldocling_inputs(text, images):
    """Apply the SmolDocling-specific prompt and image preprocessing."""
    if "OTSL" in text or "code" in text:
        images = [add_random_padding(img) for img in images]
    if "OCR at text at" in text or "Identify element" in text or "formula" in text:
        text = normalize_values(text, target_max=500)
    return text, images


def _doctags_to_markdown(raw_output, images):
    """Convert SmolDocling output to Markdown when it contains DocTags.

    Returns the cleaned text unchanged when no DocTags marker is present.
    """
    cleaned_output = raw_output.replace(_END_OF_UTTERANCE, "").strip()
    if not any(tag in cleaned_output for tag in _DOCTAG_MARKERS):
        return cleaned_output
    if "<chart>" in cleaned_output:
        # Charts are parsed as OTSL tables; also drop the tag that directly
        # follows the last <loc_500>.
        cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace(
            "</chart>", "</otsl>"
        )
        cleaned_output = _TRAILING_TAG_AFTER_LAST_LOC.sub(r"\1", cleaned_output)
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
        [cleaned_output], images
    )
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    return doc.export_to_markdown()


def _stream_model_response(model_name, text, images, max_new_tokens,
                           temperature, top_p, top_k, repetition_penalty):
    """Run generation for ``images`` and yield ``(raw, formatted)`` updates.

    ``model_name`` must be a key of ``MODEL_REGISTRY``. While tokens stream
    in, both tuple elements are the accumulated raw text; for SmolDocling a
    final tuple carries the post-processed output as second element.
    """
    processor, model = MODEL_REGISTRY[model_name]
    is_smoldocling = model_name == SMOLDOCLING_NAME
    if is_smoldocling:
        text, images = _prepare_smoldocling_inputs(text, images)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images]
            + [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    # NOTE(review): temperature/top_p/top_k presumably only take effect when
    # sampling is enabled in the model's generation config; confirm.
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks, so it runs in a worker thread while this generator
    # consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer, buffer
    thread.join()

    if is_smoldocling:
        yield buffer, _doctags_to_markdown(buffer, images)


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for image input using the selected model.

    Yields ``(raw_output, formatted_output)`` tuples as text streams in. An
    unknown model name or a missing image yields a single error message.
    """
    if model_name not in MODEL_REGISTRY:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    yield from _stream_model_response(
        model_name, text, [image], max_new_tokens,
        temperature, top_p, top_k, repetition_penalty,
    )


@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Generate responses for video input using the selected model.

    The video is downsampled to a handful of frames which are passed to the
    model as images. Yields ``(raw_output, formatted_output)`` tuples as text
    streams in. An unknown model name, a missing video, or a video with no
    decodable frames yields a single error message.
    """
    if model_name not in MODEL_REGISTRY:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if video_path is None:
        yield "Please upload a video.", "Please upload a video."
        return

    images = [frame for frame, _ in downsample_video(video_path)]
    if not images:
        message = "Could not read any frames from the video."
        yield message, message
        return

    yield from _stream_model_response(
        model_name, text, images, max_new_tokens,
        temperature, top_p, top_k, repetition_penalty,
    )


# Define examples for image and video inference
image_examples = [
    ["Reconstruct the doc [table] as it is.", "images/0.png"],
    ["Describe the image!", "images/8.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/1.png"],
    ["Convert this page to docling", "images/3.png"],
]

video_examples = [
    ["Explain the video in detail.", "videos/1.mp4"],
    ["Explain the video in detail.", "videos/2.mp4"],
]

# Create the Gradio Interface
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Multimodal OCR2**", elem_id="main-title")
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
                    image_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Upload Video (<= 30s)", height=290)
                    video_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
            with gr.Accordion("(Result.md)", open=False):
                formatted_output = gr.Markdown(label="(Result.md)")
            model_choice = gr.Radio(
                choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
                label="Select Model",
                value="Nanonets-OCR-s",
            )

    # Both tabs share the model selector, the sampling sliders and the outputs.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output, formatted_output],
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output, formatted_output],
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)