Spaces:
Build error
Build error
Commit
·
5ed8d35
0
Parent(s):
Duplicate from Acodis/Chart2Data
Browse files- .gitattributes +34 -0
- 2304.08069_2.png +0 -0
- README.md +14 -0
- app.py +217 -0
- chartVQA.png +0 -0
- requirements.txt +9 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
2304.08069_2.png
ADDED
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Image2Data
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.27.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: cc-by-nc-sa-4.0
|
| 11 |
+
duplicated_from: Acodis/Chart2Data
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
|
| 14 |
+
os.system("git clone https://github.com/microsoft/unilm.git; cd unilm; git checkout 9102ed91f8e56baa31d7ae7e09e0ec98e77d779c; cd ..")
|
| 15 |
+
sys.path.append("unilm")
|
| 16 |
+
|
| 17 |
+
from unilm.dit.object_detection.ditod import add_vit_config
|
| 18 |
+
from detectron2.config import CfgNode as CN
|
| 19 |
+
from detectron2.config import get_cfg
|
| 20 |
+
from detectron2.data import MetadataCatalog
|
| 21 |
+
from detectron2.engine import DefaultPredictor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#Plot settings
|
| 25 |
+
sns.set_style("darkgrid")
|
| 26 |
+
palette = sns.color_palette("pastel")
|
| 27 |
+
sns.set_palette(palette)
|
| 28 |
+
plt.switch_backend("Agg")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Load the DiT model config
|
| 33 |
+
cfg = get_cfg()
|
| 34 |
+
add_vit_config(cfg)
|
| 35 |
+
cfg.merge_from_file("unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml")
|
| 36 |
+
|
| 37 |
+
# Get the model weights
|
| 38 |
+
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
|
| 39 |
+
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 40 |
+
|
| 41 |
+
# Define the model predictor
|
| 42 |
+
predictor = DefaultPredictor(cfg)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Load the DePlot model
|
| 46 |
+
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(cfg.MODEL.DEVICE)
|
| 47 |
+
processor = AutoProcessor.from_pretrained("google/deplot")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def crop_figure(img: Image.Image , threshold: float = 0.5) -> Image.Image:
|
| 52 |
+
"""Prediction function for the figure cropping model using DiT backend.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
img (Image.Image): Input document image.
|
| 56 |
+
threshold (float, optional): Detection threshold. Defaults to 0.5.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
(Image.Image): The cropped figure image.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
|
| 63 |
+
md.set(thing_classes=["text","title","list","table","figure"])
|
| 64 |
+
|
| 65 |
+
output = predictor(np.array(img))["instances"]
|
| 66 |
+
|
| 67 |
+
boxes, scores, classes = output.pred_boxes.tensor.cpu().numpy(), output.scores.cpu().numpy(), output.pred_classes.cpu().numpy()
|
| 68 |
+
|
| 69 |
+
boxes = boxes[classes == 4] # 4 is the class for figures
|
| 70 |
+
scores = scores[classes == 4]
|
| 71 |
+
if len(boxes) == 0:
|
| 72 |
+
return []
|
| 73 |
+
|
| 74 |
+
print(boxes, scores)
|
| 75 |
+
|
| 76 |
+
# sort boxes by score
|
| 77 |
+
crop_box = boxes[np.argsort(scores)[::-1]][0]
|
| 78 |
+
|
| 79 |
+
# Add white space around the figure
|
| 80 |
+
margin = 0.1
|
| 81 |
+
box_size = crop_box[-2:] - crop_box[:2]
|
| 82 |
+
size = tuple((box_size + np.array([margin, margin]) * box_size).astype(int))
|
| 83 |
+
|
| 84 |
+
new = Image.new('RGB', size, (255, 255, 255))
|
| 85 |
+
image = img.crop(crop_box)
|
| 86 |
+
new.paste(image, (int((size[0] - image.size[0]) / 2), int(((size[1]) - image.size[1]) / 2)))
|
| 87 |
+
|
| 88 |
+
return new
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def extract_tables(image: Image.Image) -> Tuple[str]:
|
| 92 |
+
"""Prediction function for the table extraction model using DePlot backend.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
image (Image.Image): Input figure image.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Tuple[str]: The table title, the table as a pandas dataframe, and the table as an HTML string, if the table was successfully extracted.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
inputs = processor(image, text="Generate a data table using only the data you see in the graph below: ", return_tensors="pt").to(cfg.MODEL.DEVICE)
|
| 102 |
+
with torch.no_grad():
|
| 103 |
+
outputs = model.generate(**inputs, max_new_tokens=512)
|
| 104 |
+
decoded = processor.decode(outputs[0], skip_special_tokens=True)
|
| 105 |
+
|
| 106 |
+
print(decoded.replace("<0x0A>", "\n") )
|
| 107 |
+
data = [row.split(" | ") for row in decoded.split("<0x0A>")]
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
if data[0][0].lower().startswith("title"):
|
| 111 |
+
title = data[0][1]
|
| 112 |
+
table = pd.DataFrame(data[2:], columns=data[1])
|
| 113 |
+
else:
|
| 114 |
+
title = "Table"
|
| 115 |
+
table = pd.DataFrame(data[1:], columns=data[0])
|
| 116 |
+
|
| 117 |
+
return title, table, table.to_html()
|
| 118 |
+
except:
|
| 119 |
+
return "Table", list(list()), decoded.replace("<0x0A>", "\n")
|
| 120 |
+
|
| 121 |
+
def update(df: pd.DataFrame, plot_type: str) -> plt.figure:
|
| 122 |
+
"""Update callback for the gradio interface, that updates the plot based on the table data and selected plot type.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
df (pd.DataFrame): The extracted table data.
|
| 126 |
+
plot_type (str): The selected plot type to generate.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
plt.figure: The updated plot.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
plt.close("all")
|
| 133 |
+
df = df.apply(pd.to_numeric, errors="ignore")
|
| 134 |
+
|
| 135 |
+
fig = plt.figure(figsize=(8, 6))
|
| 136 |
+
ax = fig.add_subplot(111)
|
| 137 |
+
|
| 138 |
+
cols = df.columns
|
| 139 |
+
if len(cols) == 0:
|
| 140 |
+
return fig
|
| 141 |
+
|
| 142 |
+
if len(cols) > 1:
|
| 143 |
+
df.set_index(cols[0], inplace=True)
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
if plot_type == "Line":
|
| 147 |
+
sns.lineplot(data=df, ax=ax)
|
| 148 |
+
|
| 149 |
+
elif plot_type == "Bar":
|
| 150 |
+
df = df.reset_index()
|
| 151 |
+
if len(cols) == 1:
|
| 152 |
+
sns.barplot(x=df.index, y=df[df.columns[0]], ax=ax)
|
| 153 |
+
elif len(cols) == 2:
|
| 154 |
+
sns.barplot(x=df[df.columns[0]], y=df[df.columns[1]], ax=ax)
|
| 155 |
+
else:
|
| 156 |
+
df = df.melt(id_vars=cols[0], value_vars=cols[1:], value_name="Value")
|
| 157 |
+
sns.barplot(x=df[cols[0]], y=df["Value"], hue=df["variable"], ax=ax)
|
| 158 |
+
elif plot_type == "Scatter":
|
| 159 |
+
sns.scatterplot(data=df, ax=ax)
|
| 160 |
+
|
| 161 |
+
elif plot_type == "Pie":
|
| 162 |
+
ax.pie(df[df.columns[0]], labels=df.index, autopct='%1.1f%%', colors=palette)
|
| 163 |
+
ax.axis('equal')
|
| 164 |
+
except:
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
plt.tight_layout()
|
| 168 |
+
|
| 169 |
+
return fig
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
with gr.Blocks() as demo:
|
| 176 |
+
gr.Markdown("<h1 align=center>Data extraction from charts</h1>")
|
| 177 |
+
gr.Markdown("This Space illustrates an experimental extraction pipeline using two pretrained models:"
|
| 178 |
+
" DiT is used to to find figures in a document and crop them."
|
| 179 |
+
" DePlot is used to extract the data from the plot and covert it to tabular format."
|
| 180 |
+
" Alternatively, you can paste a figure directly into the right Image field for data extraction."
|
| 181 |
+
" Finally, you can re-plot the extracted table using the Plot Type selector. And copy the HTML code to paste it elsewhere.")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
with gr.Row() as row1:
|
| 185 |
+
input = gr.Image(image_mode="RGB", label="Document Page", type='pil')
|
| 186 |
+
cropped = gr.Image(image_mode="RGB", label="Cropped Image", type='pil')
|
| 187 |
+
with gr.Row() as row12:
|
| 188 |
+
crop_btn = gr.Button("Crop Figure")
|
| 189 |
+
extract_btn = gr.Button("Extract")
|
| 190 |
+
with gr.Row() as row13:
|
| 191 |
+
gr.Examples(["./2304.08069_2.png"], input)
|
| 192 |
+
gr.Examples(["./chartVQA.png"], cropped)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
title = gr.Textbox(label="Title")
|
| 196 |
+
with gr.Row() as row2:
|
| 197 |
+
|
| 198 |
+
with gr.Column() as col1:
|
| 199 |
+
tab_data = gr.DataFrame(label="Table")
|
| 200 |
+
plot_type = gr.Radio(["Line", "Bar", "Scatter", "Pie"], label="Plot Type", default="Line")
|
| 201 |
+
plot_btn = gr.Button("Plot")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
display = gr.Plot()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
with gr.Row() as row3:
|
| 208 |
+
html_data = gr.Textbox(label="HTML copy-paste").style(show_copy_button=True, copy_button_label="Copy to clipboard")
|
| 209 |
+
|
| 210 |
+
crop_btn.click(crop_figure, input, [cropped])
|
| 211 |
+
extract_btn.click(extract_tables, cropped, [title, tab_data, html_data])
|
| 212 |
+
|
| 213 |
+
plot_btn.click(update, [tab_data, plot_type], display)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
demo.launch()
|
chartVQA.png
ADDED
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.0
|
| 2 |
+
torchvision==0.15.1
|
| 3 |
+
transformers==4.28.1
|
| 4 |
+
timm==0.6.13
|
| 5 |
+
gradio==3.5
|
| 6 |
+
seaborn==0.12.2
|
| 7 |
+
shapely==2.0.1
|
| 8 |
+
Pillow
|
| 9 |
+
scipy
|