Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig | |
| import os | |
| import base64 | |
| import spaces | |
| import io | |
| from PIL import Image | |
| import numpy as np | |
| import yaml | |
| from pathlib import Path | |
| from globe import title, description, modelinfor, joinus, howto | |
| import uuid | |
| import tempfile | |
| import time | |
| import shutil | |
| import cv2 | |
| import re | |
| import warnings | |
# Check transformers version for compatibility.
# USE_LEGACY_CACHE selects which cache-clearing branch ModelCacheManager uses.
# NOTE(review): str.startswith with '4.4' / '4.5' / '4.6' also matches
# 4.40-4.49, 4.50-4.59 and 4.60-4.69, not only the 4.4.x/4.5.x/4.6.x releases.
# Confirm which versions are meant to be "legacy" before relying on this flag.
try:
    import transformers
    transformers_version = transformers.__version__
    print(f"Transformers version: {transformers_version}")
    USE_LEGACY_CACHE = transformers_version.startswith(('4.4', '4.5', '4.6'))
except Exception:
    # transformers unavailable or its version unreadable: assume a modern release.
    USE_LEGACY_CACHE = False
# Try to import spaces module for ZeroGPU compatibility. If it is missing, fall
# back to a no-op stand-in so both `@spaces.GPU` and `@spaces.GPU(...)` work in
# local development.
# NOTE(review): the unconditional `import spaces` near the top of the file runs
# first, so this fallback is only reachable once that earlier import is removed.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    import types

    SPACES_AVAILABLE = False

    def dummy_gpu_decorator(func=None, **_gpu_kwargs):
        """No-op replacement for ``spaces.GPU``.

        Called with a function (``@spaces.GPU``) it returns that function
        unchanged. Called with keyword arguments only
        (``@spaces.GPU(duration=60)``) it returns an identity decorator.
        """
        if func is None:
            return lambda wrapped: wrapped
        return func

    # A SimpleNamespace stores the function as a plain attribute, so
    # `spaces.GPU(f)` calls it unbound. An instance of a type()-built class
    # would bind it as a method and pass the instance in as `func`.
    spaces = types.SimpleNamespace(GPU=dummy_gpu_decorator)
# Suppress specific warnings that are known issues with GOT-OCR.
# Each entry is a regex matched against the start of the warning message.
# NOTE(review): warnings.filterwarnings only affects warnings.warn(); if any of
# these messages are emitted through a logger instead, they will still appear.
_SUPPRESSED_WARNING_MESSAGES = (
    "The attention mask and the pad token id were not set",
    "Setting `pad_token_id` to `eos_token_id`",
    "The attention mask is not set and cannot be inferred",
    "The `seen_tokens` attribute is deprecated",
)
for _message in _SUPPRESSED_WARNING_MESSAGES:
    warnings.filterwarnings("ignore", message=_message)
del _message
def global_cache_clear():
    """Best-effort release of cached GPU memory and transformers cache state.

    Called before model calls to reduce the chance of stale ``DynamicCache``
    state. Each step is independent and optional. Hooks that do not exist in
    the installed transformers release are skipped, and failures are
    swallowed, so clearing can never abort a request. A CUDA failure is
    printed as a warning.
    """
    import gc

    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Global cache clear warning: {str(e)}")
    # Neither of these hooks is guaranteed to exist in a given transformers
    # version; an ImportError/AttributeError just means "nothing to clear".
    try:
        from transformers.cache_utils import clear_cache
        clear_cache()
    except Exception:
        pass
    try:
        from transformers.cache_utils import DynamicCache
        if hasattr(DynamicCache, 'clear_all'):
            DynamicCache.clear_all()
    except Exception:
        pass
    # Collect objects released above so their (GPU) memory can be reclaimed.
    gc.collect()
class ModelCacheManager:
    """Call model methods while working around transformers cache errors.

    Wraps ``model`` and offers call helpers that clear caches, strip
    cache-related keyword arguments and retry when a call fails with an
    ``AttributeError`` mentioning ``get_max_length``. That error is the
    ``DynamicCache`` incompatibility this app guards against.
    """

    # Model attributes treated as cache state and deleted before calls.
    _CACHE_ATTRS = ('cache', '_cache', 'past_key_values', 'use_cache', '_past_key_values')
    # The only keyword arguments dynamic_cache_safe_call forwards to the model.
    _ESSENTIAL_PARAMS = ('ocr_type', 'render', 'save_render_file', 'ocr_box', 'ocr_color')

    def __init__(self, model):
        self.model = model
        self._clear_all_caches()

    @staticmethod
    def _without_cache_kwargs(kwargs):
        """Return a copy of ``kwargs`` without keys containing "cache" (any case)."""
        return {key: value for key, value in kwargs.items() if 'cache' not in key.lower()}

    def _clear_all_caches(self):
        """Best-effort clearing of global, model-level and generation caches.

        Each step is optional. Missing hooks are skipped and failures are
        ignored so that cache clearing can never abort an OCR call.
        """
        import gc
        from contextlib import suppress

        global_cache_clear()
        model = self.model
        # Model-provided cache hooks, if the remote model code defines them.
        for hook_name in ('clear_cache', '_clear_cache'):
            if hasattr(model, hook_name):
                with suppress(Exception):
                    getattr(model, hook_name)()
        with suppress(Exception):
            generation_config = getattr(model, 'generation_config', None)
            if hasattr(generation_config, 'clear_cache'):
                generation_config.clear_cache()
        # Drop cache-related attributes from the model itself.
        for attr in self._CACHE_ATTRS:
            if hasattr(model, attr):
                with suppress(Exception):
                    delattr(model, attr)
        # GenerationConfig.clear_cache is attempted on both version paths.
        with suppress(Exception):
            from transformers import GenerationConfig
            if hasattr(GenerationConfig, 'clear_cache'):
                GenerationConfig.clear_cache()
        if not USE_LEGACY_CACHE:
            # Newer transformers: also try the cache_utils-level hooks.
            with suppress(Exception):
                from transformers.cache_utils import clear_cache
                clear_cache()
            with suppress(Exception):
                from transformers.cache_utils import DynamicCache
                if hasattr(DynamicCache, 'clear_all'):
                    DynamicCache.clear_all()
        # Release memory freed by the attribute deletions above.
        with suppress(Exception):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        gc.collect()

    def safe_call(self, method_name, *args, **kwargs):
        """Call ``model.<method_name>``, retrying on a ``get_max_length`` error.

        On that specific ``AttributeError`` the caches are cleared and the
        call is retried. If the retry fails, it is attempted once more
        without any cache-related keyword arguments. Any other error is
        re-raised unchanged.
        """
        # Looked up outside the try so a missing method is reported directly.
        method = getattr(self.model, method_name)
        try:
            return method(*args, **kwargs)
        except AttributeError as e:
            if "get_max_length" not in str(e):
                raise
        self._clear_all_caches()
        try:
            return method(*args, **kwargs)
        except Exception:
            return method(*args, **self._without_cache_kwargs(kwargs))

    def direct_call(self, method_name, *args, **kwargs):
        """Clear caches and call without cache kwargs; fall back to safe_call."""
        try:
            self._clear_all_caches()
            method = getattr(self.model, method_name)
            return method(*args, **self._without_cache_kwargs(kwargs))
        except Exception:
            return self.safe_call(method_name, *args, **kwargs)

    def legacy_call(self, method_name, *args, **kwargs):
        """Call path for older transformers versions; falls back to direct_call."""
        try:
            clean_kwargs = self._without_cache_kwargs(kwargs)
            self._clear_all_caches()
            method = getattr(self.model, method_name)
            return method(*args, **clean_kwargs)
        except Exception:
            return self.direct_call(method_name, *args, **kwargs)

    def dynamic_cache_safe_call(self, method_name, *args, **kwargs):
        """Call with cache attributes removed and only essential OCR kwargs.

        The model's cache attributes are removed for the duration of the call
        and restored afterwards, also when the call raises. If the call fails
        with the ``DynamicCache`` / ``get_max_length`` error, it is retried
        with positional arguments only.

        Raises:
            Exception: "DynamicCache safe call failed: ..." when the
                positional-only retry also fails; other errors propagate.
        """
        saved_attrs = {}
        for attr in self._CACHE_ATTRS:
            if hasattr(self.model, attr):
                saved_attrs[attr] = getattr(self.model, attr)
                try:
                    delattr(self.model, attr)
                except Exception:
                    pass
        try:
            self._clear_all_caches()
            minimal_kwargs = {
                key: value for key, value in kwargs.items() if key in self._ESSENTIAL_PARAMS
            }
            method = getattr(self.model, method_name)
            try:
                return method(*args, **minimal_kwargs)
            except AttributeError as e:
                if not ("get_max_length" in str(e) and "DynamicCache" in str(e)):
                    raise
            try:
                return method(*args)
            except Exception as final_error:
                raise Exception(
                    f"DynamicCache safe call failed: {str(final_error)}"
                ) from final_error
        finally:
            # Always put the model back the way we found it.
            for attr, value in saved_attrs.items():
                try:
                    setattr(self.model, attr, value)
                except Exception:
                    pass
def _load_tokenizer(model_name):
    """Load the tokenizer for model_name, using eos as pad token if none is set."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def initialize_model_safely(model_name='ucaslcl/GOT-OCR2_0'):
    """Load the GOT-OCR tokenizer and model and wrap the model in a cache manager.

    The primary load uses float16 on CUDA (float32 on CPU) and sets the
    model's pad/eos token ids from the tokenizer. If anything in that path
    raises, a plainer load without those settings is attempted.

    Args:
        model_name: Hugging Face repo id to load. Defaults to the GOT-OCR 2.0
            checkpoint this app was written for.

    Returns:
        tuple: ``(model, tokenizer, cache_manager)``.

    Raises:
        RuntimeError: if both the primary and the fallback load fail.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        tokenizer = _load_tokenizer(model_name)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            device_map=device,
            use_safetensors=True,
            pad_token_id=tokenizer.eos_token_id,
            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
        )
        model = model.eval().to(device)
        # Pad with eos (see _load_tokenizer) and keep eos consistent with the tokenizer.
        model.config.pad_token_id = tokenizer.eos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        return model, tokenizer, ModelCacheManager(model)
    except Exception as primary_error:
        print(f"Error initializing model: {str(primary_error)}")
        # Fallback: default dtype and no explicit token-id overrides.
        try:
            tokenizer = _load_tokenizer(model_name)
            model = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map=device,
                use_safetensors=True,
            )
            model = model.eval().to(device)
            return model, tokenizer, ModelCacheManager(model)
        except Exception as fallback_error:
            raise RuntimeError(
                f"Failed to initialize model: {str(primary_error)}. "
                f"Fallback also failed: {str(fallback_error)}"
            ) from fallback_error
# Initialize model, tokenizer, and cache manager once at import time; the
# request handlers below reuse these module-level objects.
model, tokenizer, cache_manager = initialize_model_safely()

# Scratch directories for uploaded images and rendered HTML results.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
# exist_ok avoids the race between an os.path.exists check and os.makedirs.
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    os.makedirs(folder, exist_ok=True)
def image_to_base64(image):
    """Encode *image* as PNG and return the bytes as a base64 ASCII string.

    ``image`` must provide a PIL-style ``save(fp, format=...)`` method.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
def direct_model_call(model, method_name, *args, **kwargs):
    """Call ``model.<method_name>`` after dropping every cache-related kwarg.

    A keyword argument is dropped when its name contains "cache" in any
    letter case. Positional arguments are forwarded unchanged.
    """
    forwarded = {name: val for name, val in kwargs.items() if 'cache' not in name.lower()}
    target = getattr(model, method_name)
    return target(*args, **forwarded)
# Keyword arguments forwarded on the first attempt; everything else is dropped.
_ESSENTIAL_OCR_KWARGS = ('ocr_type', 'render', 'save_render_file', 'ocr_box', 'ocr_color')
# Model attributes temporarily removed by the attribute-stripping fallback.
_TRANSIENT_CACHE_ATTRS = ('cache', '_cache', 'past_key_values', 'use_cache')


def _is_dynamic_cache_error(error):
    """True when *error*'s message mentions both get_max_length and DynamicCache."""
    message = str(error)
    return "get_max_length" in message and "DynamicCache" in message


def _clear_model_caches(model):
    """Best-effort: run the model's cache hooks and empty the CUDA cache."""
    from contextlib import suppress

    for hook_name in ('clear_cache', '_clear_cache'):
        if hasattr(model, hook_name):
            with suppress(Exception):
                getattr(model, hook_name)()
    with suppress(Exception):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    with suppress(Exception):
        generation_config = getattr(model, 'generation_config', None)
        if hasattr(generation_config, 'clear_cache'):
            generation_config.clear_cache()


def _call_without_cache_attrs(model, method_name, args, kwargs):
    """Call model.<method_name> with cache attributes temporarily removed.

    The removed attributes are restored even when the call raises.
    """
    saved = {}
    for attr_name in _TRANSIENT_CACHE_ATTRS:
        if hasattr(model, attr_name):
            saved[attr_name] = getattr(model, attr_name)
            try:
                delattr(model, attr_name)
            except Exception:
                pass
    try:
        return getattr(model, method_name)(*args, **kwargs)
    finally:
        for attr_name, value in saved.items():
            try:
                setattr(model, attr_name, value)
            except Exception:
                pass


def safe_model_call_with_dynamic_cache_fix(model, method_name, *args, **kwargs):
    """
    Comprehensive safe model call that handles DynamicCache errors with multiple fallback strategies.

    First, the model's caches are cleared and the method is called with only
    the essential OCR keyword arguments. If that fails with the
    DynamicCache/get_max_length ``AttributeError``, these fallbacks are tried
    in order. Each failure is printed and the next fallback is attempted.

    1. the global ``cache_manager.direct_call``;
    2. the global ``cache_manager.legacy_call``;
    3. a call with only ``ocr_type``;
    4. the same call with cache attributes temporarily removed from the model;
    5. a call with positional arguments only.

    Any other error from the first attempt is re-raised unchanged.

    Raises:
        Exception: "All DynamicCache workarounds failed. Last error: ..." when
            every fallback fails.
    """
    _clear_model_caches(model)
    minimal_kwargs = {key: value for key, value in kwargs.items() if key in _ESSENTIAL_OCR_KWARGS}
    try:
        return getattr(model, method_name)(*args, **minimal_kwargs)
    except AttributeError as e:
        if not _is_dynamic_cache_error(e):
            raise

    print("DynamicCache error detected, applying comprehensive workaround...")
    stripped_kwargs = {'ocr_type': kwargs['ocr_type']} if 'ocr_type' in kwargs else {}
    fallbacks = (
        ("Cache manager failed",
         lambda: cache_manager.direct_call(method_name, *args, **kwargs)),
        ("Legacy cache handling failed",
         lambda: cache_manager.legacy_call(method_name, *args, **kwargs)),
        ("Stripped parameters failed",
         lambda: getattr(model, method_name)(*args, **stripped_kwargs)),
        ("Monkey patching failed",
         lambda: _call_without_cache_attrs(model, method_name, args, stripped_kwargs)),
    )
    for failure_label, attempt in fallbacks:
        try:
            return attempt()
        except Exception as err:
            print(f"{failure_label}: {str(err)}")

    # Final fallback: no keyword arguments at all.
    try:
        return getattr(model, method_name)(*args)
    except Exception as final_error:
        raise Exception(
            f"All DynamicCache workarounds failed. Last error: {str(final_error)}"
        ) from final_error
# "x1,y1,x2,y2" with optional surrounding brackets and non-negative integer or
# decimal coordinates.
_OCR_BOX_RE = re.compile(
    r"^\s*\[?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,"
    r"\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]?\s*$"
)


def _normalize_ocr_box(ocr_box):
    """Return ocr_box as "[x1,y1,x2,y2]", or None unless it is four numbers.

    NOTE(review): the box string comes straight from the UI/MCP caller and is
    handed to the model's remote code; confirm how upstream chat() parses it.
    Restricting it to four numbers keeps arbitrary text out of that path.
    """
    if not isinstance(ocr_box, str):
        return None
    match = _OCR_BOX_RE.match(ocr_box)
    if match is None:
        return None
    return "[" + ",".join(match.groups()) + "]"


def _write_rgb_array(array, image_path):
    """Save an RGB, RGBA or grayscale array with OpenCV (which expects BGR order)."""
    if array.ndim == 2 or (array.ndim == 3 and array.shape[2] == 1):
        bgr = array
    elif array.shape[2] == 4:
        bgr = cv2.cvtColor(array, cv2.COLOR_RGBA2BGRA)
    else:
        bgr = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(image_path, bgr):
        raise OSError(f"Could not write image to {image_path}")


def _save_input_image(image, image_path):
    """Write any supported input image to image_path; return an error string or None."""
    if isinstance(image, dict):
        # ImageEditor state: only the flattened composite is used.
        image = image.get("composite")
        if image is None:
            return "Error: No composite image found in ImageEditor output"
        if not isinstance(image, (np.ndarray, Image.Image, str)):
            return "Error: Unsupported image format from ImageEditor"
    if isinstance(image, np.ndarray):
        _write_rgb_array(image, image_path)
    elif isinstance(image, Image.Image):
        image.save(image_path)
    elif isinstance(image, str):
        shutil.copy(image, image_path)
    else:
        return "Error: Unsupported image format"
    return None


def _build_ocr_call(task, ocr_type, ocr_box, ocr_color, result_path):
    """Map a UI task label to ``(model_method_name, kwargs)``; None if unknown."""
    rendered = {"render": True, "save_render_file": result_path}
    if task == "Plain Text OCR":
        return "chat", {"ocr_type": "ocr"}
    if task in ("Format Text OCR", "Render Formatted OCR"):
        return "chat", {"ocr_type": "format", **rendered}
    if task == "Multi-crop OCR":
        return "chat_crop", {"ocr_type": "format", **rendered}
    if task == "Fine-grained OCR (Box)":
        return "chat", {"ocr_type": ocr_type or "ocr", "ocr_box": ocr_box, **rendered}
    if task == "Fine-grained OCR (Color)":
        return "chat", {"ocr_type": ocr_type or "ocr", "ocr_color": ocr_color or "red", **rendered}
    return None


def _read_rendered_html(result_path):
    """Return the HTML the model rendered to result_path, or None if absent."""
    if not os.path.exists(result_path):
        return None
    with open(result_path, 'r', encoding='utf-8', errors='replace') as f:
        return f.read()


def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """
    Process image with OCR using ZeroGPU-compatible approach.

    Args:
        image: ImageEditor dict (its "composite" is used), numpy array
            (RGB, RGBA or grayscale), PIL image, or path to an image file.
        task: One of the task labels offered in the UI dropdown.
        ocr_type: "ocr" or "format" for the fine-grained tasks (default "ocr").
        ocr_box: "x1,y1,x2,y2" (brackets optional) for "Fine-grained OCR (Box)".
        ocr_color: "red", "green" or "blue" for "Fine-grained OCR (Color)"
            (default "red").

    Returns:
        tuple: ``(result, html_or_None, unique_id)`` on success, or
        ``("Error: ...", None, None)`` on failure.

    The model call is attempted through three progressively simpler wrappers,
    and the first one that succeeds wins. The temporary input image and any
    rendered HTML file are deleted before returning.
    """
    # Clear global cache at the start to prevent DynamicCache issues.
    global_cache_clear()
    if image is None:
        return "Error: No image provided", None, None
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
    try:
        save_error = _save_input_image(image, image_path)
        if save_error is not None:
            return save_error, None, None
        if task == "Fine-grained OCR (Box)":
            ocr_box = _normalize_ocr_box(ocr_box)
            if ocr_box is None:
                return "Error: OCR box must be four numbers: x1,y1,x2,y2", None, None
        call = _build_ocr_call(task, ocr_type, ocr_box, ocr_color, result_path)
        if call is None:
            return f"Error: Unsupported task: {task}", None, None
        method_name, call_kwargs = call

        attempts = (
            lambda: cache_manager.dynamic_cache_safe_call(
                method_name, tokenizer, image_path, **call_kwargs),
            lambda: safe_model_call_with_dynamic_cache_fix(
                model, method_name, tokenizer, image_path, **call_kwargs),
            lambda: cache_manager.safe_call(
                method_name, tokenizer, image_path, **call_kwargs),
        )
        last_error = None
        for attempt in attempts:
            try:
                res = attempt()
            except Exception as err:
                last_error = err
                continue
            # Only rendering tasks write an HTML file.
            html_content = (
                _read_rendered_html(result_path) if "save_render_file" in call_kwargs else None
            )
            return res, html_content, unique_id
        return f"Error: {str(last_error)}", None, None
    except Exception as e:
        return f"Error: {str(e)}", None, None
    finally:
        # The HTML has already been read into memory; nothing else uses these files.
        for temp_path in (image_path, result_path):
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
def update_image_input(task):
    """Switch between the plain uploader and the image editor for *task*.

    Returns three visibility updates, applied in order to the image uploader,
    the image editor and the "Process Edited Image" button. Only
    "Fine-grained OCR (Color)" shows the editor and its button; every other
    task shows the plain uploader instead.
    """
    use_editor = task == "Fine-grained OCR (Color)"
    return (
        gr.update(visible=not use_editor),
        gr.update(visible=use_editor),
        gr.update(visible=use_editor),
    )
def update_inputs(task):
    """Return visibility updates for the task-dependent widgets.

    The seven updates correspond, in order, to the OCR-type dropdown, the
    OCR-box textbox, the OCR-color dropdown, the image uploader, the image
    editor, the "Process" button and the "Process Edited Image" button.
    Unrecognised tasks return None.
    """
    whole_image_tasks = ("Plain Text OCR", "Format Text OCR", "Multi-crop OCR", "Render Formatted OCR")
    if task in whole_image_tasks:
        return [gr.update(visible=flag) for flag in (False, False, False, True, False, True, False)]
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            *(gr.update(visible=flag) for flag in (True, False, True, False, True, False)),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            *(gr.update(visible=flag) for flag in (False, True, False, True)),
        ]
    return None
# Characters whose presence marks a segment as LaTeX: { } [ ] \ $ _ ^ "
# ('^' is literal inside a character class; a bare r'^' pattern is the
# start-of-string anchor and would match every segment).
_LATEX_MARKER_RE = re.compile(r'[{}\[\]\\$_^"]')


def parse_latex_output(res):
    r"""Wrap runs of LaTeX-looking segments of *res* in ``$$`` delimiters.

    ``res`` is split around existing ``$$...$$`` spans. Each resulting
    segment is stripped. A segment containing any of ``{ } [ ] \ $ _ ^ "``
    counts as LaTeX, and consecutive LaTeX segments are grouped between a
    single pair of ``$$`` markers. Segments that are exactly "\n" are kept
    verbatim. The resulting pieces are joined with the separator ``$$\$$``
    followed by a newline.
    """
    segments = re.split(r'(\$\$.*?\$\$)', res, flags=re.DOTALL)
    parsed_lines = []
    in_latex = False
    latex_buffer = []
    for segment in segments:
        if segment == '\n':
            if in_latex:
                latex_buffer.append(segment)
            else:
                parsed_lines.append(segment)
            continue
        line = segment.strip()
        if _LATEX_MARKER_RE.search(line):
            if not in_latex:
                in_latex = True
                latex_buffer = ['$$']
            latex_buffer.append(line)
        else:
            if in_latex:
                # Close the open LaTeX group before the plain-text line.
                latex_buffer.append('$$')
                parsed_lines.extend(latex_buffer)
                in_latex = False
                latex_buffer = []
            parsed_lines.append(line)
    if in_latex:
        latex_buffer.append('$$')
        parsed_lines.extend(latex_buffer)
    return '$$\\$$\n'.join(parsed_lines)
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """
    Main OCR demonstration function that processes images and returns results.
    Args:
        image (Union[dict, np.ndarray, str, PIL.Image]): Input image in one of these formats: Image component state with keys: path: str | None (Path to local file) url: str | None (Public URL or base64 image) size: int | None (Image size in bytes) orig_name: str | None (Original filename) mime_type: str | None (Image MIME type) is_stream: bool (Always False) meta: dict(str, Any) OR dict: ImageEditor component state with keys: background: filepath | None layers: list[filepath] composite: filepath | None id: str | None OR np.ndarray: Raw image array str: Path to image file PIL.Image: PIL Image object
        task (Literal['Plain Text OCR', 'Format Text OCR', 'Fine-grained OCR (Box)', 'Fine-grained OCR (Color)', 'Multi-crop OCR', 'Render Formatted OCR'], default: "Plain Text OCR"): The type of OCR processing to perform: "Plain Text OCR": Basic text extraction without formatting, "Format Text OCR": Text extraction with preserved formatting, "Fine-grained OCR (Box)": Text extraction from specific bounding box regions, "Fine-grained OCR (Color)": Text extraction from regions marked with specific colors, "Multi-crop OCR": Text extraction from multiple cropped regions, "Render Formatted OCR": Text extraction with HTML rendering of formatting
        ocr_type (Literal['ocr', 'format'], default: "ocr"):The type of OCR processing to apply: "ocr": Basic text extraction without formatting "format": Text extraction with preserved formatting and structure
        ocr_box (str): Bounding box coordinates specifying the region for fine-grained OCR. Format: "x1,y1,x2,y2" where: x1,y1: Top-left corner coordinates ; x2,y2: Bottom-right corner coordinates Example: "100,100,300,200" for a box starting at (100,100) and ending at (300,200)
        ocr_color (Literal['red', 'green', 'blue'], default: "red"): Color specification for fine-grained OCR when using color-based region selection: "red": Extract text from regions marked in red "green": Extract text from regions marked in green "blue": Extract text from regions marked in blue
    Returns:
        tuple: (formatted_result, html_output)
            - formatted_result (str): Formatted OCR result text
            - html_output (str): HTML visualization if applicable
    """
    # NOTE: the Args lines above are deliberately one line each; Gradio derives
    # the MCP tool's parameter descriptions from this docstring.
    res, html_content, unique_id = process_image(image, task, ocr_type, ocr_box, ocr_color)
    if isinstance(res, str) and res.startswith("Error:"):
        return res, None
    # The model result is passed through unchecked; coerce non-strings so the
    # string post-processing below cannot crash (e.g. on None).
    if not isinstance(res, str):
        res = "" if res is None else str(res)
    # Insert a space after every "\title" command in the model output.
    formatted_res = res.replace("\\title", "\\title ")
    if not html_content:
        return formatted_res, None
    # Embed the rendered HTML as a base64 data: URI, used both for an inline
    # iframe preview and for a download link.
    encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download_link = f'<a href="{data_uri}" download="result_{unique_id}.html">Download Full Result</a>'
    return formatted_res, f"{download_link}<br>{iframe}"
def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete stale files from the upload and result scratch folders.

    Args:
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are removed (default: one hour).
        folders: Directories to scan. Defaults to ``UPLOAD_FOLDER`` and
            ``RESULTS_FOLDER``.

    Only regular files are removed; sub-directories are left alone. A file
    that disappears between listing and deletion is skipped.
    """
    if folders is None:
        folders = [UPLOAD_FOLDER, RESULTS_FOLDER]
    cutoff = time.time() - max_age_seconds
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                # Removed by someone else between glob() and stat()/unlink().
                continue
# Gradio UI: page layout followed by event wiring.
# NOTE(review): this nesting was reconstructed from a copy of the file whose
# indentation had been stripped. Confirm that the two process buttons, the
# howto markdown and the output widgets sit in the intended containers.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Page header.
    with gr.Row():
        gr.Markdown(title)
    # Description and model information, side by side.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(description)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(modelinfor)
                gr.Markdown(joinus)
    # Collapsible walkthrough for the colour-marking workflow.
    with gr.Row():
        with gr.Accordion("How to use Fine-grained OCR (Color)", open=False):
            with gr.Row():
                gr.Image("res/image/howto_1.png", label="Select the Following Parameters")
                gr.Image("res/image/howto_2.png", label="Click on Paintbrush in the Image Editor")
                gr.Image("res/image/howto_3.png", label="Select your Brush Color (Red)")
                gr.Image("res/image/howto_4.png", label="Make a Box Around The Text")
            with gr.Row():
                with gr.Group():
                    gr.Markdown(howto)
    with gr.Row():
        # Left column: image input and task options.
        with gr.Column(scale=1):
            with gr.Group():
                image_input = gr.Image(type="filepath", label="Input Image")
                # Hidden initially; the task-change handlers show it only for
                # "Fine-grained OCR (Color)".
                image_editor = gr.ImageEditor(label="Image Editor", type="pil", visible=False)
                task_dropdown = gr.Dropdown(
                    choices=[
                        "Plain Text OCR",
                        "Format Text OCR",
                        "Fine-grained OCR (Box)",
                        "Fine-grained OCR (Color)",
                        "Multi-crop OCR",
                        "Render Formatted OCR"
                    ],
                    label="Select Task",
                    value="Plain Text OCR"
                )
                # Task-specific options start hidden; update_inputs reveals
                # them for the fine-grained tasks.
                ocr_type_dropdown = gr.Dropdown(
                    choices=["ocr", "format"],
                    label="OCR Type",
                    visible=False
                )
                ocr_box_input = gr.Textbox(
                    label="OCR Box (x1,y1,x2,y2)",
                    placeholder="[100,100,200,200]",
                    visible=False
                )
                ocr_color_dropdown = gr.Dropdown(
                    choices=["red", "green", "blue"],
                    label="OCR Color",
                    visible=False
                )
                # Disabled generation controls, kept for reference:
                # with gr.Row():
                #     max_new_tokens_slider = gr.Slider(50, 500, step=10, value=150, label="Max New Tokens")
                #     no_repeat_ngram_size_slider = gr.Slider(1, 10, step=1, value=2, label="No Repeat N-gram Size")
            # "Process" reads image_input; "Process Edited Image" reads image_editor.
            submit_button = gr.Button("Process")
            editor_submit_button = gr.Button("Process Edited Image", visible=False)
        # Right column: OCR text and rendered HTML.
        with gr.Column(scale=1):
            with gr.Group():
                output_markdown = gr.Textbox(label="🫴🏻📸GOT-OCR")
                output_html = gr.HTML(label="🫴🏻📸GOT-OCR")
    # Both handlers below fire on task change. The second one sets
    # image_input / image_editor / editor_submit_button visibility to the same
    # values the first one already sets.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, image_input, image_editor, submit_button, editor_submit_button]
    )
    task_dropdown.change(
        update_image_input,
        inputs=[task_dropdown],
        outputs=[image_input, image_editor, editor_submit_button]
    )
    # Same OCR entry point for both buttons; only the image source differs.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )
    editor_submit_button.click(
        ocr_demo,
        inputs=[image_editor, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )
if __name__ == "__main__":
    # Purge scratch files left over from earlier runs (older than an hour),
    # then serve the UI with the MCP server endpoint enabled and SSR disabled.
    cleanup_old_files()
    demo.launch(mcp_server=True, ssr_mode=False)