import os
from typing import Tuple

import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VisionEncoderDecoderModel,
    TrOCRProcessor,
)
from huggingface_hub import login
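

# Optional Hugging Face Hub login: lets the app pull gated or private weights.
# Failures are swallowed deliberately so the app still starts without a token.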
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
    except Exception:
        pass


TITLE = "Picture to Problem Solver"
DESCRIPTION = (
    "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question.\n\n"
    "⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
)
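

# OCR stage: TrOCR reads printed text from the uploaded image. Note that
# trocr-base-printed is trained on single-line printed text, so dense
# multi-line documents may come back truncated.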
OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "microsoft/trocr-base-printed")
ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID)
ocr_model.eval()
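

# Reasoning LLM, overridable via the LLM_MODEL_ID env var. bfloat16 is used
# on GPUs that support it; CPU falls back to float32.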
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
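
# Many causal-LM tokenizers ship without a pad token; reuse EOS for padding.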
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
if llm_tokenizer.pad_token_id is None and llm_tokenizer.eos_token_id is not None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token
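
# Load the model once at startup. device_map="auto" places weights across
# available GPUs. Note: the `dtype` kwarg requires a recent transformers
# release; older versions spell it `torch_dtype`.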
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    dtype=dtype,
    low_cpu_mem_usage=True,
    device_map="auto" if device == "cuda" else None,
)
llm_model.eval()
if device == "cpu":
    llm_model.to(device)
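
# Guarantee an EOS token so generation can terminate; if one has to be added,
# resize the embeddings to match the new vocabulary size.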
eos_token_id = llm_tokenizer.eos_token_id
if eos_token_id is None:
    llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
    llm_model.resize_token_embeddings(len(llm_tokenizer))
    eos_token_id = llm_tokenizer.eos_token_id
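

# Prompt scaffolding. If the checkpoint ships a chat template,
# llm_tokenizer.apply_chat_template may format prompts better; the raw string
# here keeps the demo model-agnostic.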
SYSTEM_INSTRUCTION = (
    "You are a precise, step-by-step technical assistant. "
    "You excel at math, programming (Python, C++), and scientific reasoning. "
    "Be concise, show steps when helpful, and avoid hallucinations."
)

USER_PROMPT_TEMPLATE = (
    "Extracted text from the image:\n"
    "-----------------------------\n"
    "{ocr_text}\n"
    "-----------------------------\n"
    "{question_hint}"
)


def build_prompt(ocr_text: str, user_question: str) -> str:
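    """Compose the system instruction, OCR text, and optional user question
    into a single plain-text prompt."""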
    if user_question and user_question.strip():
        q = f"User question: {user_question.strip()}"
    else:
        q = "Please summarize the key information and explain any math/code/science content."
    return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
        ocr_text=(ocr_text or "").strip() or "(no text detected)",
        question_hint=q,
    )


@torch.inference_mode()
def run_pipeline(
    image: Image.Image,
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> Tuple[str, str]:
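    """Run OCR on the image, then ask the LLM about the extracted text.

    Returns (extracted_text, answer); errors are returned as text in the
    second slot so the Gradio UI never raises.
    """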
    if image is None:
        return "", "Please upload an image."
    try:
        # .convert("RGB") guards against RGBA/grayscale uploads, which the
        # image processor cannot feed to the 3-channel vision encoder.
        pixel_values = ocr_processor(images=image.convert("RGB"), return_tensors="pt").pixel_values
        ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
        extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()
    except Exception as e:
        return "", f"OCR failed: {e}"

    # Step 2: build the LLM prompt around the extracted text.
    prompt = build_prompt(extracted_text, question)
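
    # Step 3: generate an answer. Greedy decoding when temperature == 0;
    # nucleus sampling otherwise.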
    try:
        inputs = llm_tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(llm_model.device) for k, v in inputs.items()}

        do_sample = temperature > 0
        generation_kwargs = dict(
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            eos_token_id=eos_token_id,
            pad_token_id=llm_tokenizer.pad_token_id if llm_tokenizer.pad_token_id is not None else eos_token_id,
        )
        # Clamp and forward the sampling knobs only when sampling is enabled,
        # which avoids transformers' unused-flag warnings in greedy mode.
        if do_sample:
            generation_kwargs["temperature"] = min(temperature, 1.5)
            generation_kwargs["top_p"] = max(0.1, min(top_p, 1.0))

        output_ids = llm_model.generate(**inputs, **generation_kwargs)
        # Decode only the newly generated tokens; stripping the prompt as a
        # string prefix is brittle because decoding can normalize whitespace.
        new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
        gen_text = llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        gen_text = f"LLM inference failed: {e}"

    return extracted_text, gen_text
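

# Gradio UI: image and optional question on the left, OCR text and the
# model's answer on the right.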
def demo_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                question = gr.Textbox(
                    label="Ask a question about the image (optional)",
                    placeholder="e.g., Summarize, extract key numbers, explain this formula, convert code to Python...",
                )
                with gr.Accordion("Generation settings (advanced)", open=False):
                    max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")

                run_btn = gr.Button("Run")

            with gr.Column(scale=1):
                ocr_out = gr.Textbox(label="Extracted Text (OCR)", lines=8)
                llm_out = gr.Markdown(label="AI Answer", elem_id="ai-answer")
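
        # run_pipeline returns (ocr_text, answer); the outputs list below
        # maps them to the two result components in order.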
        run_btn.click(
            run_pipeline,
            inputs=[image_input, question, max_new_tokens, temperature, top_p],
            outputs=[ocr_out, llm_out],
        )

        gr.Markdown(
            "---\n**Licensing reminder:** facebook/MobileLLM-R1-950M is released for non-commercial research use. "
            "Review the model card before production use."
        )

    return demo
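

# Launch the app when the script is run directly (this is also the entry
# point when hosted as a Hugging Face Space).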
if __name__ == "__main__":
    demo = demo_ui()
    demo.launch()