import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2

from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

# These imports are from the docling_core package (part of the Docling project).
# If you have 'docling_core' installed, you can uncomment them.
# from docling_core.types.doc import DoclingDocument, DocTagsDocument

import re
import ast
import html

# --- Constants ---
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Model Loading ---
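# All five OCR / vision-language checkpoints below are loaded eagerly at startup in
# float16, moved to the selected device, and kept in eval mode; generate_response()
# later picks one of them by name via get_model_and_processor().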
# Load Nanonets-OCR-s
MODEL_ID_M = "nanonets/Nanonets-OCR-s"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load MonkeyOCR
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G,
    trust_remote_code=True,
    subfolder=SUBFOLDER,
    torch_dtype=torch.float16
).to(device).eval()

# Load Typhoon-OCR-7B
MODEL_ID_L = "scb10x/typhoon-ocr-7b"
processor_l = AutoProcessor.from_pretrained(MODEL_ID_L, trust_remote_code=True)
model_l = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_L,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load SmolDocling-256M-preview
MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load Thyme-RL
MODEL_ID_N = "Kwai-Keye/Thyme-RL"
processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_N,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# --- Preprocessing and Helper Functions ---
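# add_random_padding() is applied to SmolDocling OTSL/code prompts further below; the
# border is filled with the image's top-left corner pixel so the extra margin blends
# with the page background instead of adding a hard frame.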
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image based on its size."""
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
    return padded_image

def normalize_values(text, target_max=500):
    """Normalize numerical values in text to a target maximum."""
    def normalize_list(values):
        max_value = max(values) if values else 1
        return [round((v / max_value) * target_max) for v in values]

    def process_match(match):
        num_list = ast.literal_eval(match.group(0))
        normalized = normalize_list(num_list)
        return "".join([f"<loc_{num}>" for num in normalized])

    pattern = r"\[([\d\.\s,]+)\]"
    normalized_text = re.sub(pattern, process_match, text)
    return normalized_text

def downsample_video(video_path):
    """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # Sample 10 evenly spaced frames across the whole video
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            # Guard against OpenCV reporting an FPS of 0 for files with broken metadata
            timestamp = round(i / fps, 2) if fps else 0.0
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames

# Formats SmolDocling DocTags output; falls back to the raw text if docling_core
# is not installed.
def format_smoldocling_output(buffer_text, images):
    cleaned_output = buffer_text.replace("<end_of_utterance>", "").strip()
    # Check if docling_core is available and was imported
    if 'DocTagsDocument' in globals() and 'DoclingDocument' in globals():
        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
            if "<chart>" in cleaned_output:
                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            markdown_output = doc.export_to_markdown()
            return markdown_output
    # Fallback if the library is not available or no DocTags are present
    return cleaned_output

# --- Core Generation Logic ---
def get_model_and_processor(model_name):
    """Helper to select model and processor."""
    if model_name == "Nanonets-OCR-s":
        return processor_m, model_m
    elif model_name == "MonkeyOCR-Recognition":
        return processor_g, model_g
    elif model_name == "SmolDocling-256M-preview":
        return processor_x, model_x
    elif model_name == "Typhoon-OCR-7B":
        return processor_l, model_l
    elif model_name == "Thyme-RL":
        return processor_n, model_n
    else:
        return None, None

def is_video_file(filepath):
    """Check if a file has a common video extension."""
    if not filepath:
        return False
    video_extensions = ['.mp4', '.mov', '.avi', '.mkv', '.webm']
    return any(filepath.lower().endswith(ext) for ext in video_extensions)

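# generate_response() builds a chat-template prompt with one image placeholder per
# frame (a single image, or the 10 sampled video frames), launches model.generate()
# on a background thread, and streams partial text to the UI via TextIteratorStreamer.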
@spaces.GPU  # request a ZeroGPU slot for the duration of each call (this is why `spaces` is imported)
def generate_response(
    media_file: str,
    query: str,
    model_name: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float
):
    """Unified generation function for both image and video."""
    if media_file is None:
        yield "Please upload an image or video file first."
        return

    processor, model = get_model_and_processor(model_name)
    if not processor or not model:
        yield "Invalid model selected."
        return

    media_type = "video" if is_video_file(media_file) else "image"
    try:
        if media_type == "video":
            frames = downsample_video(media_file)
            images = [frame for frame, _ in frames]
        else:  # image
            images = [Image.open(media_file)]
    except Exception as e:
        yield f"Error processing file: {e}"
        return

    if model_name == "SmolDocling-256M-preview":
        if "OTSL" in query or "code" in query:
            images = [add_random_padding(img) for img in images]
        if "OCR at text at" in query or "Identify element" in query or "formula" in query:
            query = normalize_values(query, target_max=500)

    messages = [
        {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": query}]}
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,  # required for temperature/top_p to take effect
        "temperature": temperature,
        "top_p": top_p,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer

    if model_name == "SmolDocling-256M-preview":
        formatted_output = format_smoldocling_output(buffer, images)
        yield formatted_output
    else:
        yield buffer.strip()

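# Quick sanity check (the path and prompt are taken from the example lists below and
# assume those assets exist); from a Python shell you could stream a result like this:
#
#     for partial in generate_response("images/2.jpg", "OCR the image", "Nanonets-OCR-s", 512, 0.6, 0.9):
#         print(partial)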
# --- Gradio Interface ---

# --- Examples ---
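# The image/video paths below must exist in the Space repository (images/ and
# videos/ folders); gr.Examples loads them from disk by relative path.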
image_examples = [
    ["images/0.png", "Reconstruct the doc [table] as it is."],
    ["images/8.png", "Describe the image!"],
    ["images/2.jpg", "OCR the image"],
    ["images/1.png", "Convert this page to docling"],
    ["images/3.png", "Convert this page to docling"],
    ["images/4.png", "Convert chart to OTSL."],
    ["images/5.jpg", "Convert code to text"],
    ["images/6.jpg", "Convert this table to OTSL."],
    ["images/7.jpg", "Convert formula to latex."],
]

video_examples = [
    ["videos/1.mp4", "Explain the video in detail."],
    ["videos/2.mp4", "Explain the video in detail."]
]

all_examples = image_examples + video_examples

# --- UI Styling and Helper Functions ---
css = """
body, .gradio-container { font-family: 'Inter', sans-serif; }
.main-container { padding: 20px; }
.sidebar { background-color: #F7F7F7; border-right: 1px solid #E0E0E0; padding: 15px; border-radius: 15px; }
.chat-window { min-height: 60vh; border: 1px solid #E0E0E0; border-radius: 15px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }
.input-bar { padding: 10px; border-radius: 15px; background-color: #FFFFFF; border: 1px solid #E0E0E0; margin-top: 20px; }
.submit-button { background-color: #007AFF !important; color: white !important; font-weight: bold !important; }
.media-display { text-align: center; background-color: #F0F0F0; border-radius: 10px; padding: 10px; margin-bottom: 20px; }
.media-display img, .media-display video { max-height: 400px; margin: auto; }
"""

def handle_file_upload(file):
    if file is None:
        return None, gr.update(visible=False), gr.update(visible=False)
    if is_video_file(file.name):
        return file.name, gr.update(visible=False), gr.update(value=file.name, visible=True)
    else:
        return file.name, gr.update(value=file.name, visible=True), gr.update(visible=False)

def handle_example_click(file_path, query):
    if is_video_file(file_path):
        # Update state, hide image, show video, update query
        return file_path, gr.update(visible=False), gr.update(value=file_path, visible=True), query
    else:
        # Update state, show image, hide video, update query
        return file_path, gr.update(value=file_path, visible=True), gr.update(visible=False), query

def clear_all():
    return None, None, None, "### Output will be shown here", ""

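# Layout: a sidebar column holds the generation settings, and the main column holds
# the media preview, examples, the Markdown output window, and the input bar. The
# media_file_path State stores the path of the current upload and is shared by all
# event handlers.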
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    media_file_path = gr.State(None)

    with gr.Row(elem_classes="main-container"):
        # --- Sidebar ---
        with gr.Column(scale=1, elem_classes="sidebar"):
            gr.Markdown("### OCR Conversations")
            add_conv_btn = gr.Button("+ Add Conversation")
            gr.Markdown("---")
            gr.Markdown("#### Advanced Options")
            with gr.Accordion("⚙️ Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=256,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=64,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.6
                )
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9
                )

        # --- Main Content Panel ---
        with gr.Column(scale=4):
            gr.Markdown("# Multimodal OCR")

            with gr.Column(elem_classes="media-display"):
                image_display = gr.Image(type="filepath", label="Image Preview", visible=False)
                video_display = gr.Video(label="Video Preview", visible=False)
                gr.Markdown("Upload an image or video to begin.")

            # Define query_input here (unrendered) so gr.Examples can reference it;
            # it is rendered inside the input bar further below.
            query_input = gr.Textbox(
                placeholder="Enter your query here...",
                show_label=False,
                scale=4,
                render=False,
            )

            gr.Examples(
                examples=all_examples,
                inputs=[media_file_path, query_input],  # Pass component objects
                outputs=[media_file_path, image_display, video_display, query_input],
                fn=handle_example_click,
                label="Examples (Click to run)",
                cache_examples=True
            )

            output_display = gr.Markdown(elem_classes="chat-window", value="### Output will be shown here")

            with gr.Row(elem_classes="input-bar"):
                upload_btn = gr.UploadButton("📁 Add Files", file_types=["image", "video"])
                model_dropdown = gr.Dropdown(
                    choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "Thyme-RL", "Typhoon-OCR-7B", "SmolDocling-256M-preview"],
                    label="Select Model",
                    value="Nanonets-OCR-s"
                )
                # Place the query_input (defined above) here in the input bar
                query_input.render()
                submit_btn = gr.Button("▶", elem_classes="submit-button")

    # --- Event Handlers ---
    upload_btn.upload(
        fn=handle_file_upload,
        inputs=[upload_btn],
        outputs=[media_file_path, image_display, video_display]
    )

    submit_btn.click(
        fn=generate_response,
        inputs=[media_file_path, query_input, model_dropdown, max_new_tokens, temperature, top_p],
        outputs=[output_display]
    )

    add_conv_btn.click(
        fn=clear_all,
        outputs=[media_file_path, image_display, video_display, output_display, query_input]
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True, show_error=True)
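# Note: share=True only matters when running locally; on Hugging Face Spaces the app
# is served by the Space itself and no share link is created.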