File size: 1,729 Bytes
fda320c
ca5d6b4
2ed3b5f
ca5d6b4
 
fda320c
ca5d6b4
 
 
2ed3b5f
ca5d6b4
fda320c
ca5d6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ed3b5f
 
 
 
 
 
 
ca5d6b4
2ed3b5f
 
 
 
 
ca5d6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer

# Load the PaddleOCR-VL checkpoint (tokenizer, processor, model) once at import
# time so every request handled by ocr_inference reuses the same objects.
# NOTE: trust_remote_code=True executes Python code shipped in the model repo.
model_name = "PaddlePaddle/PaddleOCR-VL"
_remote_code = {"trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(model_name, **_remote_code)
processor = AutoProcessor.from_pretrained(model_name, **_remote_code)
model = AutoModel.from_pretrained(model_name, **_remote_code)

# Move the weights onto the GPU only when CUDA is present; otherwise the model
# stays wherever from_pretrained placed it.
if torch.cuda.is_available():
    model = model.cuda()

@spaces.GPU
def ocr_inference(image):
    """
    Perform OCR on the input image using PaddleOCR-VL.

    Args:
        image: The uploaded image. The Gradio input is configured with
            type="pil"; anything that is not a PIL image is passed through
            ``Image.fromarray`` first.

    Returns:
        The text decoded from the tokens the model generated, the message
        "Please upload an image." when ``image`` is None, or an
        "Error during OCR: ..." string if any inference step raises (the
        traceback is also logged).
    """
    if image is None:
        return "Please upload an image."

    try:
        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploads can be RGBA, grayscale ("L") or palette ("P") images;
        # normalise to 3-channel RGB before handing them to the processor.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare inputs
        prompt = "Extract all text from this image."
        inputs = processor(images=image, text=prompt, return_tensors="pt")

        if torch.cuda.is_available():
            # Only tensors can be moved to the GPU; leave any non-tensor
            # entries as-is instead of failing on `.cuda()`.
            inputs = {
                k: v.cuda() if torch.is_tensor(v) else v
                for k, v in inputs.items()
            }

        # Run OCR inference
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=512)

        generated = outputs[0]
        # For decoder-only models, generate() returns the prompt tokens
        # followed by the newly generated ones; drop the prompt so its text
        # is not echoed into the OCR result.
        prompt_ids = inputs.get("input_ids")
        is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
        if prompt_ids is not None and not is_encoder_decoder:
            generated = generated[prompt_ids.shape[-1]:]

        # Decode the output
        return tokenizer.decode(generated, skip_special_tokens=True)
    except Exception as e:
        # Top-level UI boundary: report the failure to the user, but keep the
        # traceback in the server logs instead of silently discarding it.
        logging.getLogger(__name__).exception("OCR inference failed")
        return f"Error during OCR: {str(e)}"

# Build the UI: one image upload goes in, the extracted text comes out.
image_input = gr.Image(type="pil", label="Upload Image for OCR")
text_output = gr.Textbox(label="Extracted Text")

demo = gr.Interface(
    fn=ocr_inference,
    inputs=image_input,
    outputs=text_output,
    title="PaddleOCR-VL OCR Demo",
    description="Upload an image to extract text using PaddlePaddle/PaddleOCR-VL model",
)

# Start the Gradio app.
demo.launch()