# app.py
# Page metadata captured alongside this file: uploader npaleti2002,
# commit b3c036b ("Update app.py", verified), size 6.73 kB.
import os
from typing import Tuple
import gradio as gr
from PIL import Image
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
VisionEncoderDecoderModel,
TrOCRProcessor,
)
from huggingface_hub import login
import os
# Log in to the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable; without it the login call is skipped.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Heading and introductory text rendered at the top of the Gradio page.
TITLE = "Picture to Problem Solver"
# Two paragraphs separated by a blank line: what the app does, then the
# model licensing note.
DESCRIPTION = "\n\n".join(
    (
        "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question.",
        "⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use.",
    )
)
# ---------------------------
# Load OCR (TrOCR)
# ---------------------------
# The checkpoint can be overridden with the OCR_MODEL_ID environment variable.
# The default is the "printed" variant (typed/scanned text); for handwriting,
# point OCR_MODEL_ID at microsoft/trocr-base-handwritten instead.
OCR_MODEL_ID = os.environ.get("OCR_MODEL_ID", "microsoft/trocr-base-printed")
ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
# The OCR model is only used for generation, so switch it to eval mode right away.
ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID).eval()
# ---------------------------
# Load MobileLLM
# ---------------------------
# The checkpoint can be overridden with the LLM_MODEL_ID environment variable.
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")

# Use the GPU when one is visible, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# bfloat16 is chosen only on CUDA devices that report bf16 support;
# every other configuration (including CPU) loads the weights as float32.
if device == "cuda" and torch.cuda.is_bf16_supported():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    # Automatic device placement is requested on CUDA only.
    device_map=("auto" if device == "cuda" else None),
)
llm_model.eval()
if device == "cpu":
    llm_model.to(device)

# Make sure an EOS token id is available for generation. If the tokenizer has
# none (rare), register "</s>" and grow the embedding matrix to match.
if llm_tokenizer.eos_token_id is None:
    llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
    llm_model.resize_token_embeddings(len(llm_tokenizer))
eos_token_id = llm_tokenizer.eos_token_id
# Instruction text placed at the very start of every prompt.
SYSTEM_INSTRUCTION = (
    "You are a precise, step-by-step technical assistant. "
    "You excel at math, programming (Python, C++), and scientific reasoning. "
    "Be concise, show steps when helpful, and avoid hallucinations. "
)

# Layout of the user part of the prompt: the OCR text fenced by dashed rules,
# followed by the question (or a default request).
USER_PROMPT_TEMPLATE = (
    "Extracted text from the image:\n"
    "-----------------------------\n"
    "{ocr_text}\n"
    "-----------------------------\n"
    "{question_hint}"
)


def build_prompt(ocr_text: str, user_question: str) -> str:
    """Assemble the full LLM prompt from OCR output and an optional question.

    Args:
        ocr_text: Text extracted from the image. ``None``, empty, or
            whitespace-only input is rendered as ``"(no text detected)"``.
        user_question: The user's question. ``None``, empty, or
            whitespace-only input is replaced by a default request to
            summarize and explain the content.

    Returns:
        ``SYSTEM_INSTRUCTION``, a blank line, then ``USER_PROMPT_TEMPLATE``
        filled in with the cleaned text and question.
    """
    # Strip first, then test for emptiness, so whitespace-only input falls
    # through to the placeholder instead of producing an empty section.
    cleaned_text = ocr_text.strip() if ocr_text else ""
    cleaned_question = user_question.strip() if user_question else ""

    if cleaned_question:
        question_hint = f"User question: {cleaned_question}"
    else:
        question_hint = "Please summarize the key information and explain any math/code/science content."

    return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
        ocr_text=cleaned_text or "(no text detected)",
        question_hint=question_hint,
    )
@torch.inference_mode()
def run_pipeline(
    image: Image.Image,
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> Tuple[str, str]:
    """Run OCR on an image and ask the LLM about the extracted text.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        question: Optional user question forwarded to ``build_prompt``.
        max_new_tokens: Upper bound on tokens generated by the LLM.
        temperature: Sampling temperature. Values <= 0 select greedy
            decoding; positive values are capped at 1.5.
        top_p: Nucleus-sampling threshold, clamped to [0.1, 1.0]. Only used
            when sampling is enabled.

    Returns:
        ``(extracted_text, model_answer)``. When ``image`` is ``None`` the
        result is ``("", "Please upload an image.")``.
    """
    if image is None:
        return "", "Please upload an image."

    # --- OCR ---
    # Normalise to 3-channel RGB first so RGBA / palette / grayscale uploads
    # are handled the same way as plain RGB ones.
    rgb_image = image.convert("RGB")
    pixel_values = ocr_processor(images=rgb_image, return_tensors="pt").pixel_values
    # The decorator already puts the whole function under inference mode, so
    # no nested context manager is needed here.
    ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
    extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()

    # --- Build prompt for LLM ---
    prompt = build_prompt(extracted_text, question)

    # --- LLM inference ---
    inputs = llm_tokenizer(prompt, return_tensors="pt")
    target_device = llm_model.device if device == "cuda" else device
    inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    generation_kwargs = dict(
        # UI sliders may deliver the value as a float.
        max_new_tokens=int(max_new_tokens),
        eos_token_id=eos_token_id,
        pad_token_id=llm_tokenizer.eos_token_id,  # keep decoding clean
    )
    temperature = float(temperature)
    if temperature > 0:
        # Sampling parameters are only forwarded when sampling is enabled.
        generation_kwargs.update(
            do_sample=True,
            temperature=min(temperature, 1.5),
            top_p=max(0.1, min(float(top_p), 1.0)),
        )
    else:
        # Greedy decoding: temperature/top_p are meaningless here, so they
        # are left out rather than passed as 0.0.
        generation_kwargs["do_sample"] = False

    output_ids = llm_model.generate(**inputs, **generation_kwargs)

    # The returned sequence starts with the prompt tokens; keep only what was
    # generated after them. Slicing by token count does not depend on the
    # decoded text reproducing the prompt string exactly.
    new_token_ids = output_ids[0][prompt_length:]
    gen_text = llm_tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    return extracted_text, gen_text
def demo_ui():
    """Build the Gradio Blocks interface and return it (without launching).

    Layout: a left column with the image upload, an optional question box,
    collapsible generation settings and the Run button; a right column with
    the OCR text and the model's answer. Clicking Run calls ``run_pipeline``.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                question = gr.Textbox(
                    label="Ask a question about the image (optional)",
                    placeholder="e.g., Summarize, extract key numbers, explain this formula, write Python to do X...",
                )
                with gr.Accordion("Generation settings (advanced)", open=False):
                    max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
                run_btn = gr.Button("Run")
            with gr.Column(scale=1):
                ocr_out = gr.Textbox(label="Extracted Text (OCR)", lines=8)
                llm_out = gr.Markdown(label="AI Answer", elem_id="ai-answer")

        run_btn.click(
            run_pipeline,
            inputs=[image_input, question, max_new_tokens, temperature, top_p],
            outputs=[ocr_out, llm_out],
        )

        # The examples are prompts only, so they are bound to the question
        # box alone. Binding the image input with an empty-string value would
        # hand the image component an invalid path and replace whatever the
        # user had already uploaded.
        gr.Examples(
            label="Try these sample prompts (use with your own images)",
            examples=[
                ["Summarize the document."],
                ["Extract all dates and amounts, then total the amounts."],
                ["Explain the equation and solve for x."],
                ["Convert the pseudocode in the image to Python."],
            ],
            inputs=[question],
        )

        gr.Markdown(
            "—\n**Licensing reminder:** facebook/MobileLLM-R1-950M is typically released for non-commercial research use. "
            "Review the model card before production use."
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then start the Gradio server.
    demo = demo_ui()
    demo.launch()