npaleti2002 committed (verified)
Commit dc4a84a · 1 Parent(s): 8f66e16

Create app.py

Files changed (1)
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
import os
from typing import Tuple

import gradio as gr
from PIL import Image

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VisionEncoderDecoderModel,
    TrOCRProcessor,
)


TITLE = "Picture to Problem Solver"
DESCRIPTION = (
    "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question."
    "\n\n⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
)

# ---------------------------
# Load OCR (TrOCR)
# ---------------------------
# Use the "printed" variant for typed/scanned text.
# If you expect handwriting, switch to: microsoft/trocr-base-handwritten
OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "microsoft/trocr-base-printed")
ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID)
ocr_model.eval()
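
# (Note, not part of the original commit: TrOCR checkpoints are trained on single
# text-line crops, so a full-page upload typically yields only one recognized line.
# For multi-line documents, a common workaround is to segment the page into line
# images first and loop `ocr_model.generate(...)` over the crops.)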

# ---------------------------
# Load MobileLLM
# ---------------------------
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")

# Device & dtype selection that plays nice on Spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep dtype conservative to avoid OOM on CPU Spaces
torch_dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    device_map="auto" if device == "cuda" else None,
)
llm_model.eval()
if device == "cpu":
    llm_model.to(device)

# Ensure EOS/BOS tokens exist
eos_token_id = llm_tokenizer.eos_token_id
if eos_token_id is None:
    # Fallback: add one if truly missing (rare)
    llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
    llm_model.resize_token_embeddings(len(llm_tokenizer))
    eos_token_id = llm_tokenizer.eos_token_id
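
# Optional safeguard (an assumption, not in the original commit): some checkpoints
# ship without a pad token, and aliasing it to EOS keeps padding/generation warnings quiet.
if llm_tokenizer.pad_token_id is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token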

SYSTEM_INSTRUCTION = (
    "You are a precise, step-by-step technical assistant. "
    "You excel at math, programming (Python, C++), and scientific reasoning. "
    "Be concise, show steps when helpful, and avoid hallucinations. "
)

USER_PROMPT_TEMPLATE = (
    "Extracted text from the image:\n"
    "-----------------------------\n"
    "{ocr_text}\n"
    "-----------------------------\n"
    "{question_hint}"
)

def build_prompt(ocr_text: str, user_question: str) -> str:
    if user_question and user_question.strip():
        q = f"User question: {user_question.strip()}"
    else:
        q = "Please summarize the key information and explain any math/code/science content."

    return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
        ocr_text=ocr_text.strip() if ocr_text else "(no text detected)",
        question_hint=q,
    )
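
# Alternative sketch (an assumption, not in the original commit): if the MobileLLM
# tokenizer ships a chat template, instruction-style prompting usually works better
# through tokenizer.apply_chat_template than raw string concatenation, e.g.:
#
#     messages = [
#         {"role": "system", "content": SYSTEM_INSTRUCTION},
#         {"role": "user", "content": USER_PROMPT_TEMPLATE.format(
#             ocr_text=ocr_text, question_hint=q)},
#     ]
#     prompt = llm_tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )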


@torch.inference_mode()
def run_pipeline(
    image: Image.Image,
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> Tuple[str, str]:
    """
    Returns:
        (extracted_text, model_answer)
    """
    if image is None:
        return "", "Please upload an image."

    # --- OCR ---
    # TrOCR wants pixel_values prepared by its processor
    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
    with torch.inference_mode():
        ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
    extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()

    # --- Build prompt for LLM ---
    prompt = build_prompt(extracted_text, question)

    # --- LLM Inference ---
    inputs = llm_tokenizer(prompt, return_tensors="pt")
    if device == "cuda":
        inputs = {k: v.to(llm_model.device) for k, v in inputs.items()}
    else:
        inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True if temperature > 0 else False,
        temperature=max(0.0, min(temperature, 1.5)),
        top_p=max(0.1, min(top_p, 1.0)),
        eos_token_id=eos_token_id,
        pad_token_id=llm_tokenizer.eos_token_id,  # keep decoding clean
    )

    output_ids = llm_model.generate(**inputs, **generation_kwargs)
    # We only want the newly generated part for readability, so decode just the
    # tokens that come after the prompt instead of stripping echoed text afterwards.
    prompt_len = inputs["input_ids"].shape[1]
    gen_text = llm_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()

    return extracted_text, gen_text
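
# (Usage sketch, not part of the original commit; "sample.png" is a hypothetical file.)
# The OCR+LLM pipeline can be sanity-checked without the Gradio UI:
#     text, answer = run_pipeline(Image.open("sample.png"), "Summarize the document.")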


def demo_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                question = gr.Textbox(
                    label="Ask a question about the image (optional)",
                    placeholder="e.g., Summarize, extract key numbers, explain this formula, write Python to do X...",
                )
                with gr.Accordion("Generation settings (advanced)", open=False):
                    max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")

                run_btn = gr.Button("Run")

            with gr.Column(scale=1):
                ocr_out = gr.Textbox(label="Extracted Text (OCR)", lines=8)
                llm_out = gr.Markdown(label="AI Answer", elem_id="ai-answer")

        run_btn.click(
            run_pipeline,
            inputs=[image_input, question, max_new_tokens, temperature, top_p],
            outputs=[ocr_out, llm_out],
        )

        gr.Examples(
            label="Try these sample prompts (use with your own images)",
            examples=[
                ["", "Summarize the document."],
                ["", "Extract all dates and amounts, then total the amounts."],
                ["", "Explain the equation and solve for x."],
                ["", "Convert the pseudocode in the image to Python."],
            ],
            inputs=[image_input, question],
        )

        gr.Markdown(
            "—\n**Licensing reminder:** facebook/MobileLLM-R1-950M is typically released for non-commercial research use. "
            "Review the model card before production use."
        )

    return demo


if __name__ == "__main__":
    demo = demo_ui()
    demo.launch()
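    # (Optional note, an assumption rather than part of the original commit: on a busy
    # Space, demo.queue().launch() serializes requests so long generations don't time out.)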