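"""Gradio demo: OCR an uploaded image with allenai/olmOCR-2-7B-1025-FP8."""
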
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
from PIL import Image

# Use the GPU with float16 when CUDA is available; otherwise fall back to CPU with float32
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

def load_model():
    # Called once at import below, so no caching decorator is needed.
    try:
        print("Loading olmOCR model...")
        
        # Load with optimizations for limited resources
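        # NOTE: this is an FP8-quantized checkpoint; loading it in transformers
        # may additionally require the compressed-tensors package.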
        processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
        model = AutoModelForVision2Seq.from_pretrained(
            "allenai/olmOCR-2-7B-1025-FP8",
            torch_dtype=torch_dtype,
            device_map="auto" if device == "cuda" else None,
            low_cpu_mem_usage=True
        )
        
        if device == "cpu":
            model = model.to(device)
            
        print("Model loaded successfully!")
        return processor, model
        
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None

processor, model = load_model()

def extract_text_from_image(image):
    if processor is None or model is None:
        return "Model failed to load. The model might be too large for this environment."
    
    try:
        if image is None:
            return "Please upload an image first."
        
        # Convert the image and build the prompt. olmOCR-2 is Qwen2.5-VL based, so the
        # processor expects a chat-templated text prompt alongside the image; the
        # instruction below is a generic stand-in for olmOCR's own OCR prompt.
        image = image.convert('RGB')
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Extract all of the text in this image."},
            ],
        }]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
        
        # Generate with optimizations
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,  # small cap for speed; may truncate text-dense pages
                do_sample=False,     # greedy decoding
                num_beams=1          # no beam search: faster, slightly less accurate
            )
        
        # Decode only the newly generated tokens, skipping the echoed prompt
        generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
        text = processor.decode(generated_ids, skip_special_tokens=True)
        return text
        
    except Exception as e:
        return f"Error: {str(e)}"

demo = gr.Interface(
    fn=extract_text_from_image,
    inputs=gr.Image(type="pil", label="Document image"),
    outputs=gr.Textbox(lines=5, label="Extracted text"),
    title="olmOCR"
)

if __name__ == "__main__":
    demo.launch()
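
# Run locally with: pip install gradio transformers torch pillow accelerate
# (accelerate is needed for device_map="auto"), then: python app.py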