import os
os.system('pip install -U transformers==4.44.2')

import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# == download weights ==
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')

os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==

sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
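
# HTML shell used to render the predicted LaTeX inside an <iframe>. latex2html() splits
# this template at the "const text =" placeholder and splices the escaped LaTeX in;
# mathpix-markdown-it (loaded from the CDN below) then renders it client-side.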
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
    <meta charset="UTF-8">
    <title>Title</title>
    <script>
        const text =
    </script>
    <style>
        #content {
            max-width: 800px;
            margin: auto;
        }
    </style>
    <script>
        let script = document.createElement('script');
        script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
        document.head.append(script);
        script.onload = function() {
            const isLoaded = window.loadMathJax();
            if (isLoaded) {
                console.log('Styles loaded!')
            }
            const el = window.document.getElementById('content-text');
            if (el) {
                const options = {
                    htmlTags: true
                };
                const html = window.render(text, options);
                el.outerHTML = html;
            }
        };
    </script>
</head>
<body>
    <div id="content"><div id="content-text"></div></div>
</body>
</html>
"""
def latex2html(latex_code):
    # Wrap the prediction in display-math delimiters and strip \left/\right sizing commands.
    latex_code = '\\[' + latex_code + '\\]'
    latex_code = (latex_code.replace('\\left(', '(').replace('\\right)', ')')
                  .replace('\\left[', '[').replace('\\right]', ']')
                  .replace('\\left{', '{').replace('\\right}', '}')
                  .replace('\\left|', '|').replace('\\right|', '|')
                  .replace('\\left.', '.').replace('\\right.', '.'))
    latex_code = latex_code.replace('"', '``').replace('$', '')
    # Re-emit the LaTeX as a JavaScript string-literal concatenation ("line\n" + ...).
    latex_code_list = latex_code.split('\n')
    gt = ''
    for out in latex_code_list:
        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
    gt = gt[:-2]

    # Splice the escaped string into the template after "const text =".
    lines = template_html.split("const text =")
    new_web = lines[0] + 'const text =' + gt + lines[1]
    return new_web
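
# Build a UniMERNet model and its image processor from a YAML config (cfg_tiny/small/base.yaml).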
def load_model_and_processor(cfg_path):
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor
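
# Run formula recognition on the uploaded image with the selected model and return the
# predicted LaTeX plus an <iframe> that renders it. The otherwise unused `import spaces`
# above suggests this Space targets ZeroGPU, so the handler is wrapped with @spaces.GPU
# (assumption; the decorator is a no-op outside ZeroGPU).
@spaces.GPU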
def recognize_image(input_img, model_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    model.eval()

    # Gradio passes the image as a numpy array; reverse the channel order before
    # building the PIL image.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]

    # Render the LaTeX in an iframe via a base64-encoded data URL.
    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
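
# Clear the image input and both output widgets.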
def gradio_reset():
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)

if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == load model ==
    # Note: vis_processor is overwritten by each load; recognize_image() uses the last one.
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==

    with open("header.html", "r") as file:
        header = file.read()
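
    # Two-column UI: model selector, input image, and controls on the left;
    # predicted LaTeX and the rendered formula on the right.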
    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root)
                                  if _.endswith("png")],
                        inputs=input_img,
                    )
            with gr.Column():
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label="Rendered html", show_label=True)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)