import os
import random
import re
import ast
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from docling_core.types.doc import DoclingDocument, DocTagsDocument

# --- Constants ---
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Model Loading ---
def load_model(model_id, model_class, subfolder=None):
    """Generic function to load a model and its processor."""
    processor_kwargs = {"trust_remote_code": True}
    model_kwargs = {"trust_remote_code": True, "torch_dtype": torch.float16}
    if subfolder:
        processor_kwargs["subfolder"] = subfolder
        model_kwargs["subfolder"] = subfolder
    processor = AutoProcessor.from_pretrained(model_id, **processor_kwargs)
    model = model_class.from_pretrained(model_id, **model_kwargs).to(DEVICE).eval()
    return processor, model
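
# All five checkpoints are loaded eagerly at startup and kept in eval mode on
# DEVICE. SmolDocling is loaded through the generic AutoModelForVision2Seq
# class; the remaining models share the Qwen2.5-VL architecture.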

# Load Nanonets-OCR-s
processor_m, model_m = load_model(
    "nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration
)

# Load MonkeyOCR
processor_g, model_g = load_model(
    "echo840/MonkeyOCR", Qwen2_5_VLForConditionalGeneration, subfolder="Recognition"
)

# Load Typhoon-OCR-7B
processor_l, model_l = load_model(
    "scb10x/typhoon-ocr-7b", Qwen2_5_VLForConditionalGeneration
)

# Load SmolDocling-256M-preview
processor_x, model_x = load_model(
    "ds4sd/SmolDocling-256M-preview", AutoModelForVision2Seq
)

# Load Thyme-RL
processor_n, model_n = load_model(
    "Kwai-Keye/Thyme-RL", Qwen2_5_VLForConditionalGeneration
)

MODEL_MAPPING = {
    "Nanonets-OCR-s": (processor_m, model_m),
    "MonkeyOCR-Recognition": (processor_g, model_g),
    "Typhoon-OCR-7B": (processor_l, model_l),
    "SmolDocling-256M-preview": (processor_x, model_x),
    "Thyme-RL": (processor_n, model_n),
}
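# Maps the UI model names to (processor, model) pairs; these keys must stay in
# sync with the gr.Radio choices defined in the interface below.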

# --- Preprocessing Functions ---
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image based on its size."""
    # Note: with the defaults min_percent == max_percent, so the pad ratio is
    # effectively fixed at 10% of each dimension.
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    # Use the top-left corner pixel as the fill color so the padding blends in.
    corner_pixel = image.getpixel((0, 0))
    padded_image = ImageOps.expand(
        image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel
    )
    return padded_image
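
# Padding is applied only for SmolDocling "OTSL"/"code" prompts; see the
# model-specific preprocessing branch in _generate_response below.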
def normalize_values(text, target_max=500):
    """Normalize numerical values in text to a target maximum for SmolDocling."""

    def normalize_list(values):
        max_value = max(values) if values else 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        try:
            num_list = ast.literal_eval(match.group(0))
            normalized = normalize_list(num_list)
            return "".join([f"<loc_{num}>" for num in normalized])
        except (ValueError, SyntaxError):
            return match.group(0)

    pattern = r"\[([\d\.\s,]+)\]"
    return re.sub(pattern, process_match, text)
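
# Example: normalize_values("[10, 20, 40]") yields "<loc_125><loc_250><loc_500>",
# since each bracketed list is rescaled so its largest value maps to target_max.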
def downsample_video(video_path, num_frames=10):
    """Downsample a video to evenly spaced frames, returning PIL images."""
    if not video_path:
        return []
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:  # unreadable or empty video
        vidcap.release()
        return []
    frames = []
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            # OpenCV returns BGR; convert to RGB before building a PIL image.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(image_rgb))
    vidcap.release()
    return frames

# --- Core Generation Logic ---
def _generate_response(model_name, text, images, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Helper function to handle model inference."""
    if not images:
        yield "Please upload an image or video.", ""
        return

    try:
        processor, model = MODEL_MAPPING[model_name]
    except KeyError:
        yield "Invalid model selected.", ""
        return

    # Model-specific preprocessing
    if model_name == "SmolDocling-256M-preview":
        if any(keyword in text for keyword in ["OTSL", "code"]):
            images = [add_random_padding(img) for img in images]
        if any(keyword in text for keyword in ["OCR at text at", "Identify element", "formula"]):
            text = normalize_values(text, target_max=500)

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}] * len(images) + [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(DEVICE)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
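    # Generation runs in a background thread; TextIteratorStreamer yields the
    # decoded chunks as they are produced, so the UI can update incrementally.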
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer, buffer

    # Model-specific post-processing
    if model_name == "SmolDocling-256M-preview":
        cleaned_output = buffer.replace("<end_of_utterance>", "").strip()
        is_doc_tag = any(
            tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]
        )
        if is_doc_tag:
            if "<chart>" in cleaned_output:
                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
            # Strip a stray tag emitted immediately after the final <loc_500> token.
            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                yield buffer, markdown_output
            except Exception as e:
                yield buffer, f"Error processing Docling output: {e}"
        else:
            yield buffer, cleaned_output

@spaces.GPU  # request a ZeroGPU device for the duration of the call
def generate_for_image(model_name, text, image, *args):
    """Generate responses for a single image input."""
    if image is None:
        yield "Please upload an image.", ""
        return
    yield from _generate_response(model_name, text, [image], *args)

@spaces.GPU  # request a ZeroGPU device for the duration of the call
def generate_for_video(model_name, text, video_path, *args):
    """Generate responses for video input by downsampling frames."""
    if video_path is None:
        yield "Please upload a video.", ""
        return
    frames = downsample_video(video_path)
    if not frames:
        yield "Could not process video. Please check the file.", ""
        return
    # Sampled frames are passed as a multi-image prompt to the shared pipeline.
    yield from _generate_response(model_name, text, frames, *args)

# --- Gradio Interface ---
css = """
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
    font-weight: bold !important;
    border: none !important;
    transition: background-color 0.3s ease;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
.output-container {
    border: 2px solid #4682B4;
    border-radius: 10px;
    padding: 20px;
    height: 100%;
}
"""

# Define examples
image_examples = [
    ["Reconstruct the doc [table] as it is.", "images/0.png"],
    ["Describe the image!", "images/8.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/1.png"],
    ["Convert this page to docling", "images/3.png"],
    ["Convert chart to OTSL.", "images/4.png"],
    ["Convert code to text", "images/5.jpg"],
    ["Convert this table to OTSL.", "images/6.jpg"],
    ["Convert formula to latex.", "images/7.jpg"],
]
video_examples = [
    ["Explain the video in detail.", "videos/1.mp4"],
    ["Explain the video in detail.", "videos/2.mp4"],
]
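# The example paths are relative, so the images/ and videos/ assets are assumed
# to ship alongside this script in the Space repository.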

with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Multimodal OCR²](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    gr.Markdown("A unified interface for state-of-the-art multimodal and document AI models. Select a model, upload an image or video, and enter a query to begin.")

    with gr.Row():
        # --- LEFT COLUMN (INPUTS) ---
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=[
                    "Nanonets-OCR-s",
                    "MonkeyOCR-Recognition",
                    "Thyme-RL",
                    "Typhoon-OCR-7B",
                    "SmolDocling-256M-preview",
                ],
                label="🤖 Select Model",
                value="Nanonets-OCR-s",
            )

            with gr.Tabs():
                with gr.TabItem("🖼️ Image Inference"):
                    image_query = gr.Textbox(label="Query", placeholder="e.g., 'OCR the document'")
                    image_upload = gr.Image(type="pil", label="Upload Image")
                    image_submit = gr.Button("Generate", elem_classes="submit-btn")
                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])

                with gr.TabItem("🎬 Video Inference"):
                    video_query = gr.Textbox(label="Query", placeholder="e.g., 'What is happening in this video?'")
                    video_upload = gr.Video(label="Upload Video")
                    video_submit = gr.Button("Generate", elem_classes="submit-btn")
                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])

            with gr.Accordion("⚙️ Advanced Options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6
                )
                top_p = gr.Slider(
                    label="Top-P", minimum=0.05, maximum=1.0, step=0.05, value=0.9
                )
                top_k = gr.Slider(label="Top-K", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2
                )
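            # Collected in the order expected by _generate_response's trailing
            # sampling parameters.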
            advanced_params = [max_new_tokens, temperature, top_p, top_k, repetition_penalty]

        # --- RIGHT COLUMN (OUTPUTS & INFO) ---
        with gr.Column(scale=2):
            with gr.Column(elem_classes="output-container"):
                gr.Markdown("## Output")
                raw_output = gr.Textbox(
                    label="Raw Output Stream", interactive=False, lines=8
                )
                formatted_output = gr.Markdown(label="Formatted Result (Markdown)")

            with gr.Accordion("💻 Model Information", open=True):
                gr.Markdown(
                    """
                    - **[Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s)**: Transforms documents into structured markdown with intelligent content recognition.
                    - **[SmolDocling-256M](https://huggingface.co/ds4sd/SmolDocling-256M-preview)**: An efficient multimodal model for converting documents to structured formats.
                    - **[MonkeyOCR-Recognition](https://huggingface.co/echo840/MonkeyOCR)**: Adopts a Structure-Recognition-Relation paradigm for efficient document processing.
                    - **[Typhoon-OCR-7B](https://huggingface.co/scb10x/typhoon-ocr-7b)**: A bilingual (Thai/English) document parsing model for real-world documents.
                    - **[Thyme-RL](https://huggingface.co/Kwai-Keye/Thyme-RL)**: Generates and executes code for image processing and complex reasoning tasks.

                    ---
                    > ⚠️ **Note**: Performance on video inference tasks is experimental and may vary between models.

                    > [Report a Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/discussions)
                    """
                )

    # --- Event Handlers ---
    image_submit.click(
        fn=generate_for_image,
        inputs=[model_choice, image_query, image_upload] + advanced_params,
        outputs=[raw_output, formatted_output],
    )
    video_submit.click(
        fn=generate_for_video,
        inputs=[model_choice, video_query, video_upload] + advanced_params,
        outputs=[raw_output, formatted_output],
    )

if __name__ == "__main__":
    # Queue requests (up to 50 pending) so streamed outputs are delivered in
    # order, then launch with a public share link and visible errors.
    demo.queue(max_size=50).launch(share=True, show_error=True)