| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from matplotlib.patches import Patch | |
| import io | |
| import cv2 | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import csv | |
| import pandas as pd | |
| from ultralytics import YOLO | |
| import torch | |
| from paddleocr import PaddleOCR | |
| import postprocess | |
| import gradio as gr | |
# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# YOLOv8 checkpoints: the first is used by table_detection(), the second by
# table_structure().
detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)

# TODO: a large det_limit_side_len is used to get better OCR results.
# NOTE: the language code must be a quoted string; the unbalanced quote that
# used to be here was a syntax error.
ocr_model = PaddleOCR(use_angle_cls=True, lang="uk", det_limit_side_len=1920)
# Labels of the table-detection model.
detection_class_names = ['table', 'table rotated']

# Labels of the table-structure model; the position in this list is the
# class id (see structure_class_map below).
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookup: class label -> class id.
structure_class_map = {
    name: class_id for class_id, name in enumerate(structure_class_names)
}

# Per-class score thresholds handed to postprocess.objects_to_cells().
# Every real class uses 0.5; "no object" gets 10, which presumably can never
# be reached by a confidence score - TODO confirm against postprocess.
structure_class_thresholds = {name: 0.5 for name in structure_class_names[:-1]}
structure_class_thresholds["no object"] = 10
def table_detection(image):
    """Run the table-detection model on ``image`` and list its boxes.

    Returns:
        One list per predicted box: the four normalized ``xywhn`` values of
        the box followed by its confidence and its class id.
    """
    boxes = detection_model.predict(image, imgsz=800)[0].boxes.cpu().numpy()
    return [
        [*xywhn, confidence, class_id]
        for xywhn, confidence, class_id in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]
def table_structure(image):
    """Run the table-structure model on ``image`` and list its boxes.

    Returns:
        One list per predicted structure element, laid out as
        ``[*xywhn, confidence, class_id]`` with ``xywhn`` being the box's
        four normalized coordinates.
    """
    prediction = structure_model.predict(image, imgsz=1024)
    boxes = prediction[0].boxes.cpu().numpy()

    rows = []
    for xywhn, confidence, class_id in zip(boxes.xywhn, boxes.conf, boxes.cls):
        rows.append([*xywhn, confidence, class_id])
    return rows
def crop_image(image, detection_result):
    """Crop the first detected table out of ``image``.

    Args:
        image: Array indexed as ``[row, column]`` or
            ``[row, column, channel]``.
        detection_result: Sequence of detections. The first four values of a
            detection are the normalized ``x_center, y_center, width,
            height`` of its box; any further values are ignored.

    Returns:
        The slice of ``image`` covering the first detection, grown by 10
        pixels on every side and clipped to the image bounds. ``image``
        itself is returned when ``detection_result`` is empty.
    """
    padding = 10  # pixels added around the detected box

    # TODO: only the first detected table is returned.
    if len(detection_result) == 0:
        return image

    height, width = image.shape[:2]
    x_center, y_center, box_w, box_h = detection_result[0][:4]

    # Normalized centre/size -> absolute corner coordinates, padded and clipped.
    x1 = max(0, int((x_center - box_w / 2) * width) - padding)
    y1 = max(0, int((y_center - box_h / 2) * height) - padding)
    x2 = min(width, int((x_center + box_w / 2) * width) + padding)
    y2 = min(height, int((y_center + box_h / 2) * height) + padding)

    # Index only the two spatial axes so that images without a channel axis
    # are cropped as well.
    return image[y1:y2, x1:x2]
def convert_stucture(ocr_result, image, structure_result):
    """Turn raw structure-model boxes plus OCR tokens into table cells.

    Args:
        ocr_result: List of OCR tokens; each token is a dict with a ``'bbox'``
            entry that is compared against the table box.
        image: Image array the structure boxes refer to; only its height and
            width are read.
        structure_result: Sequence of predictions laid out as
            ``[x_center, y_center, width, height, score, class_id]`` with
            normalized box values.

    Returns:
        The ``(table_structures, cells, confidence_score)`` triple produced by
        ``postprocess.objects_to_cells``.
    """
    height, width = image.shape[:2]

    # Normalized centre/size boxes -> absolute [x1, y1, x2, y2] pixel boxes.
    table_objects = []
    for prediction in structure_result:
        x_center, y_center, box_w, box_h = prediction[:4]
        table_objects.append({
            'bbox': [
                int((x_center - box_w / 2) * width),
                int((y_center - box_h / 2) * height),
                int((x_center + box_w / 2) * width),
                int((y_center + box_h / 2) * height),
            ],
            'score': float(prediction[4]),
            'label': int(prediction[5]),
        })

    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring 'table' box as the table region.
    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        best_table = max(table_candidates, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No 'table' box was predicted: fall back to the whole image so that
        # no OCR token is discarded, whatever the image size.
        table_bbox = [0, 0, width, height]

    # Keep the tokens that lie at least half inside the table region.
    tokens_in_table = [
        token for token in ocr_result
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table,
        table_objects,
        tokens_in_table,
        structure_class_names,
        structure_class_thresholds,
    )
    return table_structures, cells, confidence_score
def visualize_cells(image, table_structures, cells):
    """Outline every cell on ``image`` and collect the cell texts in a table.

    Args:
        image: Image array; the cell outlines are drawn onto it in place.
        table_structures: Mapping with ``'columns'`` and ``'rows'`` entries;
            only their lengths are used, to size the output table.
        cells: Iterable of cell dicts with the keys ``'bbox'``
            (``x1, y1, x2, y2``), ``'column_nums'``, ``'row_nums'`` and
            ``'spans'`` (dicts that may carry a ``'text'`` entry).

    Returns:
        ``(image, df, data)``: the annotated image, a ``pandas.DataFrame`` of
        cell texts with string column labels, and that frame as a JSON string.
    """
    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in cells:
        x1, y1, x2, y2 = (int(value) for value in cell['bbox'][:4])

        # A cell covering several rows/columns is stored at its first
        # row/column only.
        row_num = cell['row_nums'][0]
        col_num = cell['column_nums'][0]
        data_rows[row_num][col_num] = ''.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )

        # Green outline around the cell.
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))

    df = pd.DataFrame(data_rows)
    # Replace the default integer column labels with their string form.
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()
def ocr(image):
    """Run OCR on ``image`` and return one token per recognized text line.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, where
        ``bbox`` is the axis-aligned envelope of the detected text region.
        The list is empty when nothing was recognized.
    """
    result = ocr_model.ocr(image, cls=True)

    # Only the first entry of the result is used; it is None when no text
    # was found.
    if not result or result[0] is None:
        return []

    tokens = []
    for line in result[0]:
        corners = line[0]  # corner points of the text region, each (x, y)
        text = line[1][0]  # line[1] is (text, score)

        # Take the envelope of all corners: for a skewed region the first
        # and third corner alone do not bound the text.
        xs = [point[0] for point in corners]
        ys = [point[1] for point in corners]
        tokens.append({'bbox': [min(xs), min(ys), max(xs), max(ys)], 'text': text})
    return tokens
def detect_and_crop_table(image):
    """Detect tables in ``image`` and return the crop made by ``crop_image``."""
    return crop_image(image, table_detection(image))
def recognize_table(image, ocr_result):
    """Recognize the structure of a cropped table and extract its contents.

    Args:
        image: Cropped table image.
        ocr_result: OCR tokens for that image, as returned by ``ocr``.

    Returns:
        The ``(image, df, data)`` triple produced by ``visualize_cells``.
    """
    structure_result = table_structure(image)
    print('structure_result:', structure_result)

    table_structures, cells, confidence_score = convert_stucture(
        ocr_result, image, structure_result
    )
    # Debug output of the intermediate results.
    for label, value in (
        ('table_structures:', table_structures),
        ('cells:', cells),
        ('confidence_score:', confidence_score),
    ):
        print(label, value)

    annotated, frame, payload = visualize_cells(image, table_structures, cells)
    return annotated, frame, payload
def process_pdf(image):
    """Gradio callback: extract the table found in an RGB page image.

    Args:
        image: RGB image array supplied by the Gradio image input.

    Returns:
        ``(image, df, data)``: the RGB table crop with cell outlines, the
        extracted table as a DataFrame, and the same table as a JSON string.
    """
    # The pipeline below works on BGR images.
    bgr_page = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    table_crop = detect_and_crop_table(bgr_page)
    tokens = ocr(table_crop)
    annotated, df, data = recognize_table(table_crop, tokens)
    print('df:', df)

    # Convert back to RGB for display.
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), df, data
# Texts and sample inputs shown on the demo page.
title = "Demo: table detection & recognition with Table Structure Recognition (Yolov8)."
description = "Demo for table extraction with the Table Structure Recognition (Yolov8)."
examples = [['image.png'], ['mistral_paper.png']]

# One image goes in; the annotated crop, the table and its JSON come out.
input_component = gr.Image(type="numpy")
output_components = [
    gr.Image(type="numpy", label="Detected table"),
    gr.Dataframe(label="Table as CSV"),
    gr.JSON(label="Data as JSON"),
]

app = gr.Interface(
    fn=process_pdf,
    inputs=input_component,
    outputs=output_components,
    title=title,
    description=description,
    examples=examples,
)
app.queue()
# Alternative for debugging with a public link: app.launch(debug=True, share=True)
app.launch()