import os
from typing import Tuple

import gradio as gr
from PIL import Image

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VisionEncoderDecoderModel,
    TrOCRProcessor,
)

from huggingface_hub import login

hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)


TITLE = "Picture to Problem Solver"
DESCRIPTION = (
    "Upload an image. The app reads the text with OCR, then a model tuned for "
    "math, code, and science helps answer your question."
    "\n\n⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
)

# ---------------------------
# Load OCR (TrOCR)
# ---------------------------
# Use the "printed" variant for typed/scanned text.
# If you expect handwriting, switch to: microsoft/trocr-base-handwritten
OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "microsoft/trocr-base-printed")
ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID)
ocr_model.eval()
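
# Note: TrOCR checkpoints are trained on single text-line images; multi-line
# documents usually need line segmentation first. This demo feeds the whole
# image through in one pass, which works best for short snippets.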

# ---------------------------
# Load MobileLLM
# ---------------------------
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")

# Device & dtype selection that plays nice on Spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep dtype conservative to avoid OOM on CPU Spaces
torch_dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
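# Rough sizing: ~950M parameters at float32 is about 4 GB for the weights
# alone, so the smallest CPU tiers may be tight (an estimate, not measured).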

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    device_map="auto" if device == "cuda" else None,
)
llm_model.eval()
if device == "cpu":
    llm_model.to(device)

# Ensure EOS/BOS tokens exist
eos_token_id = llm_tokenizer.eos_token_id
if eos_token_id is None:
    # Fallback: add one if truly missing (rare)
    llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
    llm_model.resize_token_embeddings(len(llm_tokenizer))
    eos_token_id = llm_tokenizer.eos_token_id


SYSTEM_INSTRUCTION = (
    "You are a precise, step-by-step technical assistant. "
    "You excel at math, programming (Python, C++), and scientific reasoning. "
    "Be concise, show steps when helpful, and avoid hallucinations. "
)

USER_PROMPT_TEMPLATE = (
    "Extracted text from the image:\n"
    "-----------------------------\n"
    "{ocr_text}\n"
    "-----------------------------\n"
    "{question_hint}"
)

def build_prompt(ocr_text: str, user_question: str) -> str:
    if user_question and user_question.strip():
        q = f"User question: {user_question.strip()}"
    else:
        q = "Please summarize the key information and explain any math/code/science content."

    return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
        ocr_text=ocr_text.strip() if ocr_text else "(no text detected)",
        question_hint=q,
    )
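
# For example, build_prompt("2x + 3 = 11", "Solve for x") yields the system
# instruction, the delimited OCR text, and "User question: Solve for x".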


@torch.inference_mode()
def run_pipeline(
    image: Image.Image,
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> Tuple[str, str]:
    """
    Returns:
        (extracted_text, model_answer)
    """
    if image is None:
        return "", "Please upload an image."

    # --- OCR ---
    # TrOCR wants pixel_values prepared by its processor
    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
    # (the @torch.inference_mode() decorator on run_pipeline already covers this call)
    ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
    extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()

    # --- Build prompt for LLM ---
    prompt = build_prompt(extracted_text, question)

    # --- LLM Inference ---
    inputs = llm_tokenizer(prompt, return_tensors="pt")
    # llm_model.device resolves correctly whether the model sits on CPU or was
    # placed with device_map="auto", so one branch covers both cases.
    inputs = {k: v.to(llm_model.device) for k, v in inputs.items()}

    # temperature == 0 means greedy decoding; pass the sampling knobs only when
    # sampling, so generate() does not warn about unused parameters.
    do_sample = temperature > 0
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=eos_token_id,
        pad_token_id=llm_tokenizer.eos_token_id,  # keep decoding clean
    )
    if do_sample:
        generation_kwargs["temperature"] = min(temperature, 1.5)
        generation_kwargs["top_p"] = max(0.1, min(top_p, 1.0))

    output_ids = llm_model.generate(**inputs, **generation_kwargs)
    # Decode only the newly generated tokens; slicing by input length is more
    # reliable than stripping the prompt string, which can differ after re-tokenization.
    prompt_len = inputs["input_ids"].shape[1]
    gen_text = llm_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()

    return extracted_text, gen_text
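
# Programmatic usage (a sketch; "sample.png" is a hypothetical local file):
#     text, answer = run_pipeline(Image.open("sample.png"), "Summarize the text.")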


def demo_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                question = gr.Textbox(
                    label="Ask a question about the image (optional)",
                    placeholder="e.g., Summarize, extract key numbers, explain this formula, write Python to do X...",
                )
                with gr.Accordion("Generation settings (advanced)", open=False):
                    max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")

                run_btn = gr.Button("Run")

            with gr.Column(scale=1):
                ocr_out = gr.Textbox(label="Extracted Text (OCR)", lines=8)
                llm_out = gr.Markdown(label="AI Answer", elem_id="ai-answer")

        run_btn.click(
            run_pipeline,
            inputs=[image_input, question, max_new_tokens, temperature, top_p],
            outputs=[ocr_out, llm_out],
        )

        gr.Examples(
            label="Try these sample prompts (use with your own images)",
            examples=[
                [None, "Summarize the document."],
                [None, "Extract all dates and amounts, then total the amounts."],
                [None, "Explain the equation and solve for x."],
                [None, "Convert the pseudocode in the image to Python."],
            ],
            inputs=[image_input, question],
        )

        gr.Markdown(
            "---\n\n**Licensing reminder:** facebook/MobileLLM-R1-950M is released for non-commercial research use. "
            "Review the model card before production use."
        )

    return demo


if __name__ == "__main__":
    demo = demo_ui()
    demo.launch()
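
# To run locally (a sketch; the package list is inferred from the imports above):
#   pip install torch transformers gradio pillow huggingface_hub
#   HF_TOKEN=<token> python app.py   # the token matters only for gated models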