Spaces:
Build error
Build error
import os
import subprocess
import sys
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
# Install detectron2 and fetch the pinned DiT sources at startup; the imports
# below depend on both. Commands are passed as argument lists (no shell), and pip
# is run through the current interpreter so detectron2 is installed into *this*
# environment rather than into whatever `pip` happens to be first on PATH.
# Like the previous os.system calls this is best-effort (check=False): a failure
# surfaces as an ImportError on the imports that follow.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ],
    check=False,
)
if not os.path.isdir("unilm"):
    # Skip the clone on restarts where the checkout is already present.
    subprocess.run(
        ["git", "clone", "https://github.com/microsoft/unilm.git"], check=False
    )
if os.path.isdir("unilm"):
    # Pin the DiT code to a known revision.
    subprocess.run(
        ["git", "checkout", "9102ed91f8e56baa31d7ae7e09e0ec98e77d779c"],
        cwd="unilm",
        check=False,
    )
sys.path.append("unilm")
| from unilm.dit.object_detection.ditod import add_vit_config | |
| from detectron2.config import CfgNode as CN | |
| from detectron2.config import get_cfg | |
| from detectron2.data import MetadataCatalog | |
| from detectron2.engine import DefaultPredictor | |
# Plot settings: dark-grid seaborn theme with the pastel palette. `palette` is a
# module-level name because the pie chart reuses it for its wedge colors.
sns.set_style("darkgrid")
palette = sns.color_palette("pastel")
sns.set_palette(palette)
# Non-interactive backend: figures are only rendered to images for the web UI.
plt.switch_backend("Agg")
# --- DiT layout detector (detectron2 cascade model trained on PubLayNet) ---
cfg = get_cfg()
# Presumably registers the ViT-specific config keys the YAML below relies on;
# it must run before merge_from_file.
add_vit_config(cfg)
cfg.merge_from_file(
    "unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml"
)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.WEIGHTS = (
    "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
)
predictor = DefaultPredictor(cfg)

# --- DePlot chart-to-table model, placed on the same device as the detector ---
processor = AutoProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(
    cfg.MODEL.DEVICE
)
def crop_figure(
    img: Image.Image, threshold: float = 0.5, margin: float = 0.1
) -> Optional[Image.Image]:
    """Detect the most confident figure on a document page and return it padded.

    Runs the DiT layout predictor on ``img`` and keeps only detections of the
    ``figure`` class whose score is at least ``threshold``. The highest-scoring
    one is cropped and centred on a white canvas that is ``margin`` times larger
    than the box in each dimension.

    Args:
        img (Image.Image): Input document image.
        threshold (float, optional): Minimum detection score for a figure to be
            considered. Defaults to 0.5.
        margin (float, optional): Extra white space as a fraction of the box
            width/height, split evenly between both sides. Defaults to 0.1.

    Returns:
        Optional[Image.Image]: The padded figure crop, or ``None`` when no figure
        passes the threshold (``None`` clears a Gradio Image output, whereas the
        previous ``[]`` could not be displayed).
    """
    md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
    md.set(thing_classes=["text", "title", "list", "table", "figure"])
    figure_class = 4  # index of "figure" in thing_classes above

    rgb = img.convert("RGB")
    # DefaultPredictor expects an (H, W, C) array in BGR order and converts it to
    # the model's input format itself, so flip PIL's RGB channels first.
    bgr = np.ascontiguousarray(np.asarray(rgb)[:, :, ::-1])
    instances = predictor(bgr)["instances"]
    boxes = instances.pred_boxes.tensor.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    classes = instances.pred_classes.cpu().numpy()

    keep = (classes == figure_class) & (scores >= threshold)
    boxes, scores = boxes[keep], scores[keep]
    if len(boxes) == 0:
        return None

    # Highest-confidence figure box, as (x0, y0, x1, y1).
    crop_box = boxes[np.argmax(scores)]
    figure = rgb.crop(tuple(float(v) for v in crop_box))

    # Centre the crop on a white canvas enlarged by `margin` in each dimension.
    box_size = crop_box[2:] - crop_box[:2]
    canvas_size = tuple(int(v) for v in box_size * (1 + margin))
    canvas = Image.new("RGB", canvas_size, (255, 255, 255))
    offset = (
        (canvas_size[0] - figure.size[0]) // 2,
        (canvas_size[1] - figure.size[1]) // 2,
    )
    canvas.paste(figure, offset)
    return canvas
def _parse_deplot_table(decoded: str) -> Tuple[str, pd.DataFrame, str]:
    """Turn DePlot's linearised table text into (title, DataFrame, HTML).

    Rows are separated by the literal token ``<0x0A>`` and cells by ``|``. An
    optional first row starting with ``title`` (any case) carries the chart
    title in its second cell. The next row is the header.

    Raises:
        IndexError: If the text has no header row (or a title row without a
            title cell).
        ValueError: If a data row has more cells than the header.
    """
    rows = [
        [cell.strip() for cell in line.split("|")]
        for line in decoded.split("<0x0A>")
        if line.strip()  # ignore empty rows, e.g. from a trailing separator
    ]
    title = "Table"
    if rows[0][0].lower().startswith("title"):
        title = rows[0][1]
        rows = rows[1:]
    table = pd.DataFrame(rows[1:], columns=rows[0])
    return title, table, table.to_html()


def extract_tables(image: Image.Image) -> Tuple[str, pd.DataFrame, str]:
    """Prediction function for the table extraction model using DePlot backend.

    Args:
        image (Image.Image): Input figure image.

    Returns:
        Tuple[str, pd.DataFrame, str]: The table title, the table as a pandas
        DataFrame and the table as an HTML string. If the model output cannot be
        parsed into a table, returns ``"Table"``, an empty DataFrame and the raw
        model output, with row separators turned into newlines.
    """
    if image is None:
        # Extract was clicked with no figure loaded.
        return "Table", pd.DataFrame(), ""

    inputs = processor(
        image,
        text="Generate a data table using only the data you see in the graph below: ",
        return_tensors="pt",
    ).to(cfg.MODEL.DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)
    decoded = processor.decode(outputs[0], skip_special_tokens=True)

    try:
        return _parse_deplot_table(decoded)
    except (IndexError, ValueError):
        # Not a well-formed table: show the raw model output instead.
        return "Table", pd.DataFrame(), decoded.replace("<0x0A>", "\n")
def _coerce_numeric(column: pd.Series) -> pd.Series:
    """Return ``column`` converted to numbers, or unchanged if any value fails to parse.

    Same behaviour as ``pd.to_numeric(column, errors="ignore")``, whose
    ``errors="ignore"`` option is deprecated since pandas 2.2.
    """
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


def update(df: pd.DataFrame, plot_type: str) -> plt.Figure:
    """Update callback for the gradio interface, that updates the plot based on the table data and selected plot type.

    When there is more than one column, the first column is used as the x axis
    / category labels and the remaining columns as series. If drawing fails
    (e.g. non-numeric data for a pie chart), the error is written onto the
    figure instead of being silently discarded.

    Args:
        df (pd.DataFrame): The extracted table data.
        plot_type (str): The selected plot type to generate
            ("Line", "Bar", "Scatter" or "Pie").

    Returns:
        plt.Figure: The updated plot.
    """
    plt.close("all")  # drop figures from earlier clicks so they don't accumulate
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    if df is None or len(df.columns) == 0:
        return fig

    df = df.apply(_coerce_numeric)
    cols = df.columns
    if len(cols) > 1:
        df = df.set_index(cols[0])

    try:
        if plot_type == "Line":
            sns.lineplot(data=df, ax=ax)
        elif plot_type == "Bar":
            if len(cols) == 1:
                # Single column: one bar per row, positioned by row number.
                sns.barplot(x=df.index, y=df[cols[0]], ax=ax)
            elif len(cols) == 2:
                flat = df.reset_index()
                sns.barplot(x=flat[cols[0]], y=flat[cols[1]], ax=ax)
            else:
                # Several series: long format so each column becomes a hue group.
                long_df = df.reset_index().melt(
                    id_vars=cols[0], value_vars=cols[1:], value_name="Value"
                )
                sns.barplot(
                    x=long_df[cols[0]],
                    y=long_df["Value"],
                    hue=long_df["variable"],
                    ax=ax,
                )
        elif plot_type == "Scatter":
            sns.scatterplot(data=df, ax=ax)
        elif plot_type == "Pie":
            ax.pie(
                df[df.columns[0]],
                labels=df.index,
                autopct="%1.1f%%",
                colors=palette,
            )
            ax.axis("equal")
    except Exception as err:  # arbitrary extracted data: plotting is best-effort
        ax.text(
            0.5,
            0.5,
            f"Could not draw a {plot_type} plot:\n{err}",
            ha="center",
            va="center",
            wrap=True,
            transform=ax.transAxes,
        )
    plt.tight_layout()
    return fig
# Gradio UI: document page -> figure crop (DiT) -> table (DePlot) -> re-plot.
with gr.Blocks() as demo:
    gr.Markdown("<h1 align=center>Data extraction from charts</h1>")
    gr.Markdown(
        "This Space illustrates an experimental extraction pipeline using two pretrained models:"
        " DiT is used to find figures in a document and crop them."
        " DePlot is used to extract the data from the plot and convert it to tabular format."
        " Alternatively, you can paste a figure directly into the right Image field for data extraction."
        " Finally, you can re-plot the extracted table using the Plot Type selector,"
        " and copy the HTML code to paste it elsewhere."
    )
    with gr.Row():
        # Named doc_image (not `input`) to avoid shadowing the builtin.
        doc_image = gr.Image(image_mode="RGB", label="Document Page", type="pil")
        cropped = gr.Image(image_mode="RGB", label="Cropped Image", type="pil")
    with gr.Row():
        crop_btn = gr.Button("Crop Figure")
        extract_btn = gr.Button("Extract")
    with gr.Row():
        gr.Examples(["./2304.08069_2.png"], doc_image)
        gr.Examples(["./chartVQA.png"], cropped)
    title = gr.Textbox(label="Title")
    with gr.Row():
        with gr.Column():
            tab_data = gr.DataFrame(label="Table")
            # `value` (not the pre-3.0 `default`) sets the initial selection.
            plot_type = gr.Radio(
                ["Line", "Bar", "Scatter", "Pie"], label="Plot Type", value="Line"
            )
            plot_btn = gr.Button("Plot")
        display = gr.Plot()
    with gr.Row():
        # Copy button via the constructor; `.style()` is deprecated and removed
        # in Gradio 4.
        html_data = gr.Textbox(label="HTML copy-paste", show_copy_button=True)

    crop_btn.click(crop_figure, doc_image, [cropped])
    extract_btn.click(extract_tables, cropped, [title, tab_data, html_data])
    plot_btn.click(update, [tab_data, plot_type], display)

if __name__ == "__main__":
    demo.launch()