Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import, division, print_function | |
| import cv2 | |
| import numpy as np | |
| import paddle | |
| from .locality_aware_nms import nms_locality | |
class EASTPostProcess(object):
    """
    The post process for EAST.

    Converts the network's score map and geometry map into a list of
    quadrilateral text boxes, one result dict per input image.

    Args:
        score_thresh: a pixel is kept as a box proposal only when its score
            is strictly greater than this value.
        cover_thresh: after NMS, a box is kept only when the mean score
            inside it is strictly greater than this value.
        nms_thresh: threshold forwarded to the locality-aware NMS routine.
        **kwargs: accepted and ignored.
    """

    def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quadrangles from anchor points and per-point corner offsets.

        Args:
            origin: array of shape (n, 2) holding one anchor point per
                proposal.
            geometry: array of shape (n, 8) holding, for each proposal, the
                offsets of the four corners relative to the anchor.

        Returns:
            Array of shape (n, 4, 2): ``anchor - offset`` for every corner.
        """
        # Repeat each anchor four times so it lines up with the 8 offsets.
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1
        )  # (n, 8)
        pred_quads = origin_concat - geometry
        return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

    def detect(
        self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
    ):
        """
        Restore text boxes from score map and geo map.

        Args:
            score_map: array whose first entry along axis 0 is the 2-D score
                map (presumably shape (1, H, W) -- confirm against caller).
            geo_map: array whose axis 0 holds the 8 geometry channels
                (presumably shape (8, H, W) -- confirm against caller).
            score_thresh: minimum (exclusive) pixel score for a proposal.
            cover_thresh: minimum (exclusive) mean score inside a box.
            nms_thresh: threshold forwarded to the NMS routine.

        Returns:
            An empty list when nothing is detected; otherwise an array with
            9 columns per box: 8 corner coordinates followed by the mean
            score inside the box.
        """
        score_map = score_map[0]
        # Move the channel axis last so geo_map[y, x] yields the 8 offsets.
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)

        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []

        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]

        # restore quad proposals; argwhere yields (row, col), so reverse to
        # (x, y). The factor 4 rescales map coordinates (presumably the map
        # is 1/4 of the network input resolution -- confirm).
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
        )
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]

        # Prefer the optional `lanms` package; fall back to the bundled
        # implementation if it is missing or fails. Catch Exception rather
        # than using a bare except so that KeyboardInterrupt / SystemExit
        # are not swallowed.
        try:
            import lanms

            boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
        except Exception:
            print(
                "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
            )
            boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
        if boxes.shape[0] == 0:
            return []

        # Here we filter some low score boxes by the average score map,
        # this is different from the original paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            # Box coordinates were scaled by 4 above; scale back to index
            # the score map.
            cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > cover_thresh]
        return boxes

    def sort_poly(self, p):
        """
        Sort the four points of a polygon into a canonical order.

        The point with the smallest ``x + y`` becomes the first point and the
        cyclic order is preserved. If the first edge is not predominantly
        horizontal, the traversal direction is reversed (first point kept).

        Args:
            p: array of shape (4, 2).

        Returns:
            Array of shape (4, 2) with the points reordered.
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        """
        Run post-processing on a batch of network outputs.

        Args:
            outs_dict: mapping with keys ``"f_score"`` and ``"f_geo"``; the
                values are indexed per image and may be ``paddle.Tensor``
                objects (converted to numpy) or array-likes.
            shape_list: per-image sequence of
                ``(src_h, src_w, ratio_h, ratio_w)``; only the two ratios
                are used, to divide the box coordinates.

        Returns:
            A list with one dict per image, ``{"points": ndarray}``, holding
            the kept boxes as int32 quadrangles.
        """
        score_list = outs_dict["f_score"]
        geo_list = outs_dict["f_geo"]
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            geo_list = geo_list.numpy()

        img_num = len(shape_list)
        dt_boxes_list = []
        for ino in range(img_num):
            boxes = self.detect(
                score_map=score_list[ino],
                geo_map=geo_list[ino],
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh,
            )
            boxes_norm = []
            if len(boxes) > 0:
                _, _, ratio_h, ratio_w = shape_list[ino]
                # Drop the score column and rescale the coordinates.
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for box in boxes:
                    box = self.sort_poly(box.astype(np.int32))
                    # Skip degenerate boxes whose first or last edge is
                    # shorter than 5 units.
                    if (
                        np.linalg.norm(box[0] - box[1]) < 5
                        or np.linalg.norm(box[3] - box[0]) < 5
                    ):
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({"points": np.array(boxes_norm)})
        return dt_boxes_list