PaddleOCR fast and simplified inference #4
by Goodsea - opened

This view is limited to 50 files because it contains too many changes. See the raw diff here.
- .gitattributes +36 -0
- README.md +3 -2
- app.py +134 -84
- db_utils.py +0 -41
- .gitignore → ocr/.gitignore +3 -36
- ocr/README.md +1 -0
- ocr/__init__.py +0 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdiparams +3 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info +0 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdmodel +3 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams +3 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info +0 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel +3 -0
- ocr/detector.py +248 -0
- ocr/inference.py +68 -0
- ocr/postprocess/__init__.py +66 -0
- ocr/postprocess/cls_postprocess.py +30 -0
- ocr/postprocess/db_postprocess.py +207 -0
- ocr/postprocess/east_postprocess.py +122 -0
- ocr/postprocess/extract_textpoint_fast.py +464 -0
- ocr/postprocess/extract_textpoint_slow.py +608 -0
- ocr/postprocess/fce_postprocess.py +234 -0
- ocr/postprocess/locality_aware_nms.py +198 -0
- ocr/postprocess/pg_postprocess.py +189 -0
- ocr/postprocess/poly_nms.py +132 -0
- ocr/postprocess/pse_postprocess/__init__.py +1 -0
- ocr/postprocess/pse_postprocess/pse/__init__.py +20 -0
- ocr/postprocess/pse_postprocess/pse/pse.pyx +72 -0
- ocr/postprocess/pse_postprocess/pse/setup.py +19 -0
- ocr/postprocess/pse_postprocess/pse_postprocess.py +100 -0
- ocr/postprocess/rec_postprocess.py +731 -0
- ocr/postprocess/sast_postprocess.py +355 -0
- ocr/postprocess/vqa_token_re_layoutlm_postprocess.py +36 -0
- ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py +96 -0
- ocr/ppocr/__init__.py +0 -0
- ocr/ppocr/data/__init__.py +79 -0
- ocr/ppocr/data/collate_fn.py +59 -0
- ocr/ppocr/data/imaug/ColorJitter.py +14 -0
- ocr/ppocr/data/imaug/__init__.py +61 -0
- ocr/ppocr/data/imaug/copy_paste.py +167 -0
- ocr/ppocr/data/imaug/east_process.py +427 -0
- ocr/ppocr/data/imaug/fce_aug.py +563 -0
- ocr/ppocr/data/imaug/fce_targets.py +671 -0
- ocr/ppocr/data/imaug/gen_table_mask.py +228 -0
- ocr/ppocr/data/imaug/iaa_augment.py +72 -0
- ocr/ppocr/data/imaug/label_ops.py +1046 -0
- ocr/ppocr/data/imaug/make_border_map.py +155 -0
- ocr/ppocr/data/imaug/make_pse_gt.py +88 -0
- ocr/ppocr/data/imaug/make_shrink_map.py +100 -0
- ocr/ppocr/data/imaug/operators.py +458 -0
    	
.gitattributes
ADDED

@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdiparams filter=lfs diff=lfs merge=lfs -text
+*.pdmodel filter=lfs diff=lfs merge=lfs -text
    	
README.md
CHANGED
(lines ending in "…" are cut off by the diff viewer)

@@ -1,12 +1,13 @@
 ---
-title: Deprem …
+title: Deprem Ocr 2
 emoji: 👀
 colorFrom: green
 colorTo: blue
 sdk: gradio
 sdk_version: 3.17.0
 app_file: app.py
-pinned: …
+pinned: false
+duplicated_from: mertcobanov/deprem-ocr-2
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py
CHANGED

Old version (left column of the diff; lines ending in "…" are cut off by the diff viewer):

@@ -1,81 +1,157 @@
-from PIL import ImageFilter, Image
-from easyocr import Reader
 import gradio as gr
-import …
 import openai
 import ast
-from transformers import pipeline
 import os

-…
-…

 openai.api_key = os.getenv("API_KEY")
-reader = Reader(["tr"])


-…
-…
-    …
-    …
-…


-# Submit button
 def get_parsed_address(input_img):

     address_full_text = get_text(input_img)
-    return …


-def …
-    …
-    …
-    …


-def …
-    …


-…
-…


 def text_dict(input):
     eval_result = ast.literal_eval(input)
     return (
-        str(eval_result["…
-        str(eval_result["…
-        str(eval_result["…
-        str(eval_result["…
-        str(eval_result["…
         str(eval_result["no"]),
-        str(eval_result["ad-soyad"]),
-        str(eval_result["dis kapi no"]),
     )


-def …
-…
-…
-…
-…

-    …
-        …
-        …
-        …
-…
     resp["input"] = ocr_input
-    dict_keys = […
     for key in dict_keys:
         if key not in resp.keys():
             resp[key] = ""
     return resp


-# User Interface
 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -86,68 +162,42 @@ with gr.Blocks() as demo:
         "Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın."
     )
     with gr.Row():
-        …
-…
-        img_area_button = gr.Button(value="Görüntüyü İşle", label="Submit")
-…
-    with gr.Column():
-        text_area = gr.Textbox(label="Metin yükleyin 👇 ", lines=8)
-        text_area_button = gr.Button(value="Metni Yükle", label="Submit")
-…
     open_api_text = gr.Textbox(label="Tam Adres")
-…
     with gr.Column():
         with gr.Row():
-            …
-            …
         with gr.Row():
-            …
-…
-            )
-            sokak = gr.Textbox(
-                label="Sokak/Cadde/Bulvar", interactive=True, show_progress=False
-            )
         with gr.Row():
-            …
         with gr.Row():
-            …
-…
-            )
-            apartman = gr.Textbox(label="apartman", interactive=True, show_progress=False)
         with gr.Row():
-            …

-    …
         get_parsed_address,
         inputs=img_area,
         outputs=open_api_text,
-        api_name="…
     )

-    …
-        …
     )

-…
     open_api_text.change(
         text_dict,
         open_api_text,
-        […
-    )
-    ocr_button = gr.Button(value="Sadece OCR kullan")
-    ocr_button.click(
-        get_text,
-        inputs=img_area,
-        outputs=text_area,
-        api_name="get-ocr-output",
     )
-    submit_button = gr.Button(value="Veriyi Birimlere Yolla")
-    submit_button.click(save_deta_db, open_api_text)
-    done_text = gr.Textbox(label="Done", value="Not Done", visible=False)
-    submit_button.click(update_component, outputs=done_text)
-    for txt in [il, ilce, mahalle, sokak, apartman, no, ad_soyad, dis_kapi_no]:
-        submit_button.click(fn=clear_textbox, inputs=txt, outputs=txt)


 if __name__ == "__main__":
-    demo.launch()

New version (right column of the diff):

@@ -1,81 +1,157 @@
 import gradio as gr
+import json
+import csv
 import openai
 import ast
 import os
+from deta import Deta
+
+import numpy as np
+from ocr import utility
+from ocr.detector import TextDetector
+from ocr.recognizer import TextRecognizer

+# Global Detector and Recognizer
+args = utility.parse_args()
+text_recognizer = TextRecognizer(args)
+text_detector = TextDetector(args)

 openai.api_key = os.getenv("API_KEY")

+args = utility.parse_args()
+text_recognizer = TextRecognizer(args)
+text_detector = TextDetector(args)

+
+def apply_ocr(img):
+    # Detect text regions
+    dt_boxes, _ = text_detector(img)
+
+    boxes = []
+    for box in dt_boxes:
+        p1, p2, p3, p4 = box
+        x1 = min(p1[0], p2[0], p3[0], p4[0])
+        y1 = min(p1[1], p2[1], p3[1], p4[1])
+        x2 = max(p1[0], p2[0], p3[0], p4[0])
+        y2 = max(p1[1], p2[1], p3[1], p4[1])
+        boxes.append([x1, y1, x2, y2])
+
+    # Recognize text
+    img_list = []
+    for i in range(len(boxes)):
+        x1, y1, x2, y2 = map(int, boxes[i])
+        img_list.append(img.copy()[y1:y2, x1:x2])
+    img_list.reverse()
+
+    rec_res, _ = text_recognizer(img_list)
+
+    # Postprocess
+    total_text = ""
+    for i in range(len(rec_res)):
+        total_text += rec_res[i][0] + " "
+
+    total_text = total_text.strip()
+    return total_text


 def get_parsed_address(input_img):

     address_full_text = get_text(input_img)
+    return openai_response(address_full_text)


+def get_text(input_img):
+    input_img = np.array(input_img)
+    result = apply_ocr(input_img)
+    print(result)
+    return " ".join(result)


+def save_csv(mahalle, il, sokak, apartman):
+    adres_full = [mahalle, il, sokak, apartman]

+    with open("adress_book.csv", "a", encoding="utf-8") as f:
+        write = csv.writer(f)
+        write.writerow(adres_full)
+    return adres_full

+
+def get_json(mahalle, il, sokak, apartman):
+    adres = {"mahalle": mahalle, "il": il, "sokak": sokak, "apartman": apartman}
+    dump = json.dumps(adres, indent=4, ensure_ascii=False)
+    return dump
+
+
+def write_db(data_dict):
+    # 2) initialize with a project key
+    deta_key = os.getenv("DETA_KEY")
+    deta = Deta(deta_key)
+
+    # 3) create and use as many DBs as you want!
+    users = deta.Base("deprem-ocr")
+    users.insert(data_dict)


 def text_dict(input):
     eval_result = ast.literal_eval(input)
+    write_db(eval_result)
+
     return (
+        str(eval_result["city"]),
+        str(eval_result["distinct"]),
+        str(eval_result["neighbourhood"]),
+        str(eval_result["street"]),
+        str(eval_result["address"]),
+        str(eval_result["tel"]),
+        str(eval_result["name_surname"]),
         str(eval_result["no"]),
     )


+def openai_response(ocr_input):
+    prompt = f"""Tabular Data Extraction You are a highly intelligent and accurate tabular data extractor from
+            plain text input and especially from emergency text that carries address information, your inputs can be text
+            of arbitrary size, but the output should be in [{{'tabular': {{'entity_type': 'entity'}} }}] JSON format Force it
+            to only extract keys that are shared as an example in the examples section, if a key value is not found in the
+            text input, then it should be ignored. Have only city, distinct, neighbourhood,
+            street, no, tel, name_surname, address Examples: Input: Deprem sırasında evimizde yer alan adresimiz: İstanbul,
+            Beşiktaş, Yıldız Mahallesi, Cumhuriyet Caddesi No: 35, cep telefonu numaram 5551231256, adim Ahmet Yilmaz
+            Output: {{'city': 'İstanbul', 'distinct': 'Beşiktaş', 'neighbourhood': 'Yıldız Mahallesi', 'street': 'Cumhuriyet Caddesi', 'no': '35', 'tel': '5551231256', 'name_surname': 'Ahmet Yılmaz', 'address': 'İstanbul, Beşiktaş, Yıldız Mahallesi, Cumhuriyet Caddesi No: 35'}}
+            Input: {ocr_input}
+            Output:
+        """
+
+    response = openai.Completion.create(
+        model="text-davinci-003",
+        prompt=prompt,
+        temperature=0,
+        max_tokens=300,
+        top_p=1,
+        frequency_penalty=0.0,
+        presence_penalty=0.0,
+        stop=["\n"],
+    )
+    resp = response["choices"][0]["text"]
+    print(resp)
+    resp = eval(resp.replace("'{", "{").replace("}'", "}"))
     resp["input"] = ocr_input
+    dict_keys = [
+        "city",
+        "distinct",
+        "neighbourhood",
+        "street",
+        "no",
+        "tel",
+        "name_surname",
+        "address",
+        "input",
+    ]
     for key in dict_keys:
         if key not in resp.keys():
             resp[key] = ""
     return resp


 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -86,68 +162,42 @@ with gr.Blocks() as demo:
         "Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın."
     )
     with gr.Row():
+        img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇")
+        ocr_result = gr.Textbox(label="Metin yükleyin 👇 ")
     open_api_text = gr.Textbox(label="Tam Adres")
+    submit_button = gr.Button(label="Yükle")
     with gr.Column():
         with gr.Row():
+            city = gr.Textbox(label="İl")
+            distinct = gr.Textbox(label="İlçe")
         with gr.Row():
+            neighbourhood = gr.Textbox(label="Mahalle")
+            street = gr.Textbox(label="Sokak/Cadde/Bulvar")
         with gr.Row():
+            tel = gr.Textbox(label="Telefon")
         with gr.Row():
+            name_surname = gr.Textbox(label="İsim Soyisim")
+            address = gr.Textbox(label="Adres")
         with gr.Row():
+            no = gr.Textbox(label="Kapı No")

+    submit_button.click(
         get_parsed_address,
         inputs=img_area,
         outputs=open_api_text,
+        api_name="upload_image",
     )

+    ocr_result.change(
+        openai_response, ocr_result, open_api_text, api_name="upload-text"
     )

     open_api_text.change(
         text_dict,
         open_api_text,
+        [city, distinct, neighbourhood, street, address, tel, name_surname, no],
     )


 if __name__ == "__main__":
+    demo.launch()
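For reviewers who want to try the new OCR path outside Gradio, here is a minimal, hedged sketch of how the detector/recognizer pair above is driven (it mirrors apply_ocr in the new app.py). The ocr.utility, ocr.detector and ocr.recognizer modules come from this PR; "screenshot.png" is only a placeholder input, not a file shipped with the Space.

# Minimal usage sketch of the new PaddleOCR path (mirrors apply_ocr above).
# Assumes this PR's ocr/ package and PP-OCRv3 model dirs are importable;
# "screenshot.png" is a placeholder image path.
import numpy as np
from PIL import Image

from ocr import utility
from ocr.detector import TextDetector
from ocr.recognizer import TextRecognizer

args = utility.parse_args()
detector = TextDetector(args)
recognizer = TextRecognizer(args)

img = np.array(Image.open("screenshot.png").convert("RGB"))

dt_boxes, _ = detector(img)              # quadrilateral text regions
crops = []
for p1, p2, p3, p4 in dt_boxes:          # axis-aligned crop around each region
    xs = (p1[0], p2[0], p3[0], p4[0])
    ys = (p1[1], p2[1], p3[1], p4[1])
    x1, y1, x2, y2 = int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))
    crops.append(img[y1:y2, x1:x2])
crops.reverse()                          # same ordering app.py uses before recognition

rec_res, _ = recognizer(crops)           # list of (text, confidence) pairs
print(" ".join(text for text, _ in rec_res))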
    	
db_utils.py
DELETED

@@ -1,41 +0,0 @@
-from deta import Deta  # Import Deta
-from pprint import pprint
-import os
-
-deta_key = os.getenv("DETA_KEY")
-deta = Deta(deta_key)
-db = deta.Base("deprem-ocr")
-
-
-def get_users_by_city(city_name, limit=10):
-
-    user = db.fetch({"city": city_name.capitalize()}, limit=limit).items
-    return user
-
-
-def get_all():
-    res = db.fetch()
-    all_items = res.items
-
-    # fetch until last is 'None'
-    while res.last:
-        res = db.fetch(last=res.last)
-        all_items += res.items
-    return all_items
-
-
-def write_db(data_dict):
-    # 2) initialize with a project key
-    deta_key = os.getenv("DETA_KEY")
-    deta = Deta(deta_key)
-
-    # 3) create and use as many DBs as you want!
-    users = deta.Base("deprem-ocr")
-    users.insert(data_dict)
-    print("Pushed to db")
-
-
-def get_latest_row(last):
-    all_items = get_all()
-    latest_items = all_items[-last:]
-    return latest_items
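Although db_utils.py is removed, the Deta write path survives as write_db() inside the new app.py. A minimal sketch of that retained call, assuming a DETA_KEY secret is configured for the Space (the record shown is purely illustrative):

# Hedged sketch of the Deta Base insert that app.py's write_db() still performs.
# DETA_KEY is expected as a Space secret; the record below is made-up example data.
import os
from deta import Deta

deta = Deta(os.getenv("DETA_KEY"))   # initialize with the project key
base = deta.Base("deprem-ocr")       # same Base name app.py uses
base.insert({"city": "İstanbul", "name_surname": "Ahmet Yılmaz", "no": "35"})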
    	
.gitignore → ocr/.gitignore
RENAMED
(lines ending in "…" are cut off by the diff viewer)

@@ -20,6 +20,7 @@ parts/
 sdist/
 var/
 wheels/
+pip-wheel-metadata/
 share/python-wheels/
 *.egg-info/
 .installed.cfg
@@ -49,7 +50,6 @@ coverage.xml
 *.py,cover
 .hypothesis/
 .pytest_cache/
-cover/

 # Translations
 *.mo
@@ -72,7 +72,6 @@ instance/
 docs/_build/

 # PyBuilder
-.pybuilder/
 target/

 # Jupyter Notebook
@@ -83,9 +82,7 @@ profile_default/
 ipython_config.py

 # pyenv
-…
-#   intended to run in multiple environments; otherwise, check them in:
-# .python-version
+.python-version

 # pipenv
 #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
@@ -94,22 +91,7 @@ ipython_config.py
 #   install all needed dependencies.
 #Pipfile.lock

-# …
-#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-#   This is especially recommended for binary packages to ensure reproducibility, and is more
-#   commonly ignored for libraries.
-#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-#   in version control.
-#   https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
 __pypackages__/

 # Celery stuff
@@ -145,18 +127,3 @@ dmypy.json

 # Pyre type checker
 .pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-#  and can be added to the global gitignore or merged into this file.  For a more nuclear
-#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
-
-.DS_Store
    	
ocr/README.md
ADDED

@@ -0,0 +1 @@
+# deprem-ocr
    	
ocr/__init__.py
ADDED

File without changes
    	
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e9518c6ab706fe87842a8de1c098f990e67f9212b67c9ef8bc4bca6dc17b91a
+size 2377917
    	
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info
ADDED

Binary file (26.4 kB)
    	
ocr/ch_PP-OCRv3_det_infer/inference.pdmodel
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74b075e6cfbc8206dab2eee86a6a8bd015a7be612b2bf6d1a1ef878d31df84f7
+size 1413260
    	
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d99d4279f7c64471b8f0be426ee09a46c0f1ecb344406bf0bb9571f670e8d0c7
+size 10614098
    	
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info
ADDED

Binary file (22 kB)
    	
ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9beb0b9520d34bde2a0f92581ed64db7e4d6c76abead8b859189ea72db9ee20
+size 1266415
    	
ocr/detector.py
ADDED
(this view is cut off after line 178 of 248)

@@ -0,0 +1,248 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
+
+os.environ["FLAGS_allocator_strategy"] = "auto_growth"
+
+import json
+import sys
+import time
+
+import cv2
+import numpy as np
+
+import utility
+from postprocess import build_post_process
+from ppocr.data import create_operators, transform
+
+
+class TextDetector(object):
+    def __init__(self, args):
+        self.args = args
+        self.det_algorithm = args.det_algorithm
+        self.use_onnx = args.use_onnx
+        pre_process_list = [
+            {
+                "DetResizeForTest": {
+                    "limit_side_len": args.det_limit_side_len,
+                    "limit_type": args.det_limit_type,
+                }
+            },
+            {
+                "NormalizeImage": {
+                    "std": [0.229, 0.224, 0.225],
+                    "mean": [0.485, 0.456, 0.406],
+                    "scale": "1./255.",
+                    "order": "hwc",
+                }
+            },
+            {"ToCHWImage": None},
+            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
+        ]
+        postprocess_params = {}
+        if self.det_algorithm == "DB":
+            postprocess_params["name"] = "DBPostProcess"
+            postprocess_params["thresh"] = args.det_db_thresh
+            postprocess_params["box_thresh"] = args.det_db_box_thresh
+            postprocess_params["max_candidates"] = 1000
+            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+            postprocess_params["use_dilation"] = args.use_dilation
+            postprocess_params["score_mode"] = args.det_db_score_mode
+        elif self.det_algorithm == "EAST":
+            postprocess_params["name"] = "EASTPostProcess"
+            postprocess_params["score_thresh"] = args.det_east_score_thresh
+            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
+            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
+        elif self.det_algorithm == "SAST":
+            pre_process_list[0] = {
+                "DetResizeForTest": {"resize_long": args.det_limit_side_len}
+            }
+            postprocess_params["name"] = "SASTPostProcess"
+            postprocess_params["score_thresh"] = args.det_sast_score_thresh
+            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
+            self.det_sast_polygon = args.det_sast_polygon
+            if self.det_sast_polygon:
+                postprocess_params["sample_pts_num"] = 6
+                postprocess_params["expand_scale"] = 1.2
+                postprocess_params["shrink_ratio_of_width"] = 0.2
+            else:
+                postprocess_params["sample_pts_num"] = 2
+                postprocess_params["expand_scale"] = 1.0
+                postprocess_params["shrink_ratio_of_width"] = 0.3
+        elif self.det_algorithm == "PSE":
+            postprocess_params["name"] = "PSEPostProcess"
+            postprocess_params["thresh"] = args.det_pse_thresh
+            postprocess_params["box_thresh"] = args.det_pse_box_thresh
+            postprocess_params["min_area"] = args.det_pse_min_area
+            postprocess_params["box_type"] = args.det_pse_box_type
+            postprocess_params["scale"] = args.det_pse_scale
+            self.det_pse_box_type = args.det_pse_box_type
+        elif self.det_algorithm == "FCE":
+            pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}}
+            postprocess_params["name"] = "FCEPostProcess"
+            postprocess_params["scales"] = args.scales
+            postprocess_params["alpha"] = args.alpha
+            postprocess_params["beta"] = args.beta
+            postprocess_params["fourier_degree"] = args.fourier_degree
+            postprocess_params["box_type"] = args.det_fce_box_type
+
+        self.preprocess_op = create_operators(pre_process_list)
+        self.postprocess_op = build_post_process(postprocess_params)
+        (
+            self.predictor,
+            self.input_tensor,
+            self.output_tensors,
+            self.config,
+        ) = utility.create_predictor(args, "det")
+
+        if self.use_onnx:
+            img_h, img_w = self.input_tensor.shape[2:]
+            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
+                pre_process_list[0] = {
+                    "DetResizeForTest": {"image_shape": [img_h, img_w]}
+                }
+        self.preprocess_op = create_operators(pre_process_list)
+
+    def order_points_clockwise(self, pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        diff = np.diff(pts, axis=1)
+        rect[1] = pts[np.argmin(diff)]
+        rect[3] = pts[np.argmax(diff)]
+        return rect
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {"image": img}
+
+        st = time.time()
+
+        data = transform(data, self.preprocess_op)
+        img, shape_list = data
+        if img is None:
+            return None, 0
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+
+        if self.use_onnx:
+            input_dict = {}
+            input_dict[self.input_tensor.name] = img
+            outputs = self.predictor.run(self.output_tensors, input_dict)
+        else:
+            self.input_tensor.copy_from_cpu(img)
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+
+        preds = {}
+        if self.det_algorithm == "EAST":
+            preds["f_geo"] = outputs[0]
+            preds["f_score"] = outputs[1]
+        elif self.det_algorithm == "SAST":
            +
                        preds["f_border"] = outputs[0]
         | 
| 180 | 
            +
                        preds["f_score"] = outputs[1]
         | 
| 181 | 
            +
                        preds["f_tco"] = outputs[2]
         | 
| 182 | 
            +
                        preds["f_tvo"] = outputs[3]
         | 
| 183 | 
            +
                    elif self.det_algorithm in ["DB", "PSE"]:
         | 
| 184 | 
            +
                        preds["maps"] = outputs[0]
         | 
| 185 | 
            +
                    elif self.det_algorithm == "FCE":
         | 
| 186 | 
            +
                        for i, output in enumerate(outputs):
         | 
| 187 | 
            +
                            preds["level_{}".format(i)] = output
         | 
| 188 | 
            +
                    else:
         | 
| 189 | 
            +
                        raise NotImplementedError
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    # self.predictor.try_shrink_memory()
         | 
| 192 | 
            +
                    post_result = self.postprocess_op(preds, shape_list)
         | 
| 193 | 
            +
                    dt_boxes = post_result[0]["points"]
         | 
| 194 | 
            +
                    if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
         | 
| 195 | 
            +
                        self.det_algorithm in ["PSE", "FCE"]
         | 
| 196 | 
            +
                        and self.postprocess_op.box_type == "poly"
         | 
| 197 | 
            +
                    ):
         | 
| 198 | 
            +
                        dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
         | 
| 199 | 
            +
                    else:
         | 
| 200 | 
            +
                        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    et = time.time()
         | 
| 203 | 
            +
                    return dt_boxes, et - st
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            if __name__ == "__main__":
         | 
| 207 | 
            +
                args = utility.parse_args()
         | 
| 208 | 
            +
                image_file_list = ["images/y.png"]
         | 
| 209 | 
            +
                text_detector = TextDetector(args)
         | 
| 210 | 
            +
                count = 0
         | 
| 211 | 
            +
                total_time = 0
         | 
| 212 | 
            +
                draw_img_save = "./inference_results"
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                if args.warmup:
         | 
| 215 | 
            +
                    img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
         | 
| 216 | 
            +
                    for i in range(2):
         | 
| 217 | 
            +
                        res = text_detector(img)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                if not os.path.exists(draw_img_save):
         | 
| 220 | 
            +
                    os.makedirs(draw_img_save)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                save_results = []
         | 
| 223 | 
            +
                for image_file in image_file_list:
         | 
| 224 | 
            +
                    img = cv2.imread(image_file)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    for _ in range(10):
         | 
| 227 | 
            +
                        st = time.time()
         | 
| 228 | 
            +
                        dt_boxes, _ = text_detector(img)
         | 
| 229 | 
            +
                        elapse = time.time() - st
         | 
| 230 | 
            +
                        print(elapse * 1000)
         | 
| 231 | 
            +
                    if count > 0:
         | 
| 232 | 
            +
                        total_time += elapse
         | 
| 233 | 
            +
                    count += 1
         | 
| 234 | 
            +
                    save_pred = (
         | 
| 235 | 
            +
                        os.path.basename(image_file)
         | 
| 236 | 
            +
                        + "\t"
         | 
| 237 | 
            +
                        + str(json.dumps([x.tolist() for x in dt_boxes]))
         | 
| 238 | 
            +
                        + "\n"
         | 
| 239 | 
            +
                    )
         | 
| 240 | 
            +
                    save_results.append(save_pred)
         | 
| 241 | 
            +
                    src_im = utility.draw_text_det_res(dt_boxes, image_file)
         | 
| 242 | 
            +
                    img_name_pure = os.path.split(image_file)[-1]
         | 
| 243 | 
            +
                    img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure))
         | 
| 244 | 
            +
                    cv2.imwrite(img_path, src_im)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                with open(os.path.join(draw_img_save, "det_results.txt"), "w") as f:
         | 
| 247 | 
            +
                    f.writelines(save_results)
         | 
| 248 | 
            +
                    f.close()
         | 
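A minimal sketch of driving the detector above and flattening each quadrilateral into an axis-aligned box, mirroring what apply_ocr in ocr/inference.py (next file) does. It assumes the snippet is run from inside the ocr/ directory with default command-line flags; the image path reuses the sample from the __main__ block and is illustrative.

    # Sketch only: detect text and convert (N, 4, 2) corner points to [x1, y1, x2, y2].
    import cv2
    import utility
    from detector import TextDetector

    args = utility.parse_args()
    detector = TextDetector(args)
    img = cv2.imread("images/y.png")      # illustrative path, same sample as above
    dt_boxes, elapse = detector(img)      # dt_boxes: (N, 4, 2) clockwise corner points
    rects = [
        [pts[:, 0].min(), pts[:, 1].min(), pts[:, 0].max(), pts[:, 1].max()]
        for pts in dt_boxes
    ]
    print(len(rects), "boxes detected in", round(elapse * 1000, 1), "ms")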
    	
        ocr/inference.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import time
         | 
| 6 | 
            +
            import requests
         | 
| 7 | 
            +
            from io import BytesIO
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import utility
         | 
| 10 | 
            +
            from detector import *
         | 
| 11 | 
            +
            from recognizer import *
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Global Detector and Recognizer
         | 
| 14 | 
            +
            args = utility.parse_args()
         | 
| 15 | 
            +
            text_recognizer = TextRecognizer(args)
         | 
| 16 | 
            +
            text_detector = TextDetector(args)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def apply_ocr(img):
         | 
| 20 | 
            +
                # Detect text regions
         | 
| 21 | 
            +
                dt_boxes, _ = text_detector(img)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                boxes = []
         | 
| 24 | 
            +
                for box in dt_boxes:
         | 
| 25 | 
            +
                    p1, p2, p3, p4 = box
         | 
| 26 | 
            +
                    x1 = min(p1[0], p2[0], p3[0], p4[0])
         | 
| 27 | 
            +
                    y1 = min(p1[1], p2[1], p3[1], p4[1])
         | 
| 28 | 
            +
                    x2 = max(p1[0], p2[0], p3[0], p4[0])
         | 
| 29 | 
            +
                    y2 = max(p1[1], p2[1], p3[1], p4[1])
         | 
| 30 | 
            +
                    boxes.append([x1, y1, x2, y2])
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                # Recognize text
         | 
| 33 | 
            +
                img_list = []
         | 
| 34 | 
            +
                for i in range(len(boxes)):
         | 
| 35 | 
            +
                    x1, y1, x2, y2 = map(int, boxes[i])
         | 
| 36 | 
            +
                    img_list.append(img.copy()[y1:y2, x1:x2])
         | 
| 37 | 
            +
                img_list.reverse()
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                rec_res, _ = text_recognizer(img_list)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                # Postprocess
         | 
| 42 | 
            +
                total_text = ""
         | 
| 43 | 
            +
                table = dict()
         | 
| 44 | 
            +
                for i in range(len(rec_res)):
         | 
| 45 | 
            +
                    table[i] = {
         | 
| 46 | 
            +
                        "text": rec_res[i][0],
         | 
| 47 | 
            +
                    }
         | 
| 48 | 
            +
                    total_text += rec_res[i][0] + " "
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                total_text = total_text.strip()
         | 
| 51 | 
            +
                return total_text
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def main():
         | 
| 55 | 
            +
                image_url = "https://i.ibb.co/kQvHGjj/aewrg.png"
         | 
| 56 | 
            +
                response = requests.get(image_url)
         | 
| 57 | 
            +
                img = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                t0 = time.time()
         | 
| 60 | 
            +
                epoch = 1
         | 
| 61 | 
            +
                for _ in range(epoch):
         | 
| 62 | 
            +
                    ocr_text = apply_ocr(img)
         | 
| 63 | 
            +
                print("Elapsed time:", (time.time() - t0) * 1000 / epoch, "ms")
         | 
| 64 | 
            +
                print("Output:", ocr_text)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            if __name__ == "__main__":
         | 
| 68 | 
            +
                main()
         | 
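apply_ocr expects an RGB image as a NumPy array and returns the recognized strings joined with spaces. A minimal local-file variant of main(), as a sketch only: "sample.png" is a hypothetical path, and importing the module builds the detector and recognizer with default flags, so run it from inside the ocr/ directory.

    # Sketch only: OCR a local file instead of the hard-coded URL in main().
    import numpy as np
    from PIL import Image
    from inference import apply_ocr   # module import loads the detector and recognizer

    img = np.array(Image.open("sample.png").convert("RGB"))   # hypothetical path
    print(apply_ocr(img))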
    	
        ocr/postprocess/__init__.py
    ADDED
    
    | @@ -0,0 +1,66 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import copy
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            __all__ = ["build_post_process"]
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .cls_postprocess import ClsPostProcess
         | 
| 8 | 
            +
            from .db_postprocess import DBPostProcess, DistillationDBPostProcess
         | 
| 9 | 
            +
            from .east_postprocess import EASTPostProcess
         | 
| 10 | 
            +
            from .fce_postprocess import FCEPostProcess
         | 
| 11 | 
            +
            from .pg_postprocess import PGPostProcess
         | 
| 12 | 
            +
            from .rec_postprocess import (
         | 
| 13 | 
            +
                AttnLabelDecode,
         | 
| 14 | 
            +
                CTCLabelDecode,
         | 
| 15 | 
            +
                DistillationCTCLabelDecode,
         | 
| 16 | 
            +
                NRTRLabelDecode,
         | 
| 17 | 
            +
                PRENLabelDecode,
         | 
| 18 | 
            +
                SARLabelDecode,
         | 
| 19 | 
            +
                SEEDLabelDecode,
         | 
| 20 | 
            +
                SRNLabelDecode,
         | 
| 21 | 
            +
                TableLabelDecode,
         | 
| 22 | 
            +
            )
         | 
| 23 | 
            +
            from .sast_postprocess import SASTPostProcess
         | 
| 24 | 
            +
            from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
         | 
| 25 | 
            +
            from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def build_post_process(config, global_config=None):
         | 
| 29 | 
            +
                support_dict = [
         | 
| 30 | 
            +
                    "DBPostProcess",
         | 
| 31 | 
            +
                    "EASTPostProcess",
         | 
| 32 | 
            +
                    "SASTPostProcess",
         | 
| 33 | 
            +
                    "FCEPostProcess",
         | 
| 34 | 
            +
                    "CTCLabelDecode",
         | 
| 35 | 
            +
                    "AttnLabelDecode",
         | 
| 36 | 
            +
                    "ClsPostProcess",
         | 
| 37 | 
            +
                    "SRNLabelDecode",
         | 
| 38 | 
            +
                    "PGPostProcess",
         | 
| 39 | 
            +
                    "DistillationCTCLabelDecode",
         | 
| 40 | 
            +
                    "TableLabelDecode",
         | 
| 41 | 
            +
                    "DistillationDBPostProcess",
         | 
| 42 | 
            +
                    "NRTRLabelDecode",
         | 
| 43 | 
            +
                    "SARLabelDecode",
         | 
| 44 | 
            +
                    "SEEDLabelDecode",
         | 
| 45 | 
            +
                    "VQASerTokenLayoutLMPostProcess",
         | 
| 46 | 
            +
                    "VQAReTokenLayoutLMPostProcess",
         | 
| 47 | 
            +
                    "PRENLabelDecode",
         | 
| 48 | 
            +
                    "DistillationSARLabelDecode",
         | 
| 49 | 
            +
                ]
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                if config["name"] == "PSEPostProcess":
         | 
| 52 | 
            +
                    from .pse_postprocess import PSEPostProcess
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    support_dict.append("PSEPostProcess")
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                config = copy.deepcopy(config)
         | 
| 57 | 
            +
                module_name = config.pop("name")
         | 
| 58 | 
            +
                if module_name == "None":
         | 
| 59 | 
            +
                    return
         | 
| 60 | 
            +
                if global_config is not None:
         | 
| 61 | 
            +
                    config.update(global_config)
         | 
| 62 | 
            +
                assert module_name in support_dict, Exception(
         | 
| 63 | 
            +
                    "post process only support {}".format(support_dict)
         | 
| 64 | 
            +
                )
         | 
| 65 | 
            +
                module_class = eval(module_name)(**config)
         | 
| 66 | 
            +
                return module_class
         | 
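build_post_process selects the class named by config["name"] and forwards the remaining keys (optionally merged with global_config) as keyword arguments. A minimal sketch, assuming the ocr package is importable from the project root; the parameter values are illustrative rather than the defaults used elsewhere in this PR.

    # Sketch only: build the DB post-process through the factory.
    from ocr.postprocess import build_post_process

    db_pp = build_post_process(
        {
            "name": "DBPostProcess",
            "thresh": 0.3,
            "box_thresh": 0.6,
            "max_candidates": 1000,
            "unclip_ratio": 1.5,
            "use_dilation": False,
            "score_mode": "fast",
        }
    )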
    	
        ocr/postprocess/cls_postprocess.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
| 1 | 
            +
            import paddle
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class ClsPostProcess(object):
         | 
| 5 | 
            +
                """Convert between text-label and text-index"""
         | 
| 6 | 
            +
             | 
| 7 | 
            +
                def __init__(self, label_list=None, key=None, **kwargs):
         | 
| 8 | 
            +
                    super(ClsPostProcess, self).__init__()
         | 
| 9 | 
            +
                    self.label_list = label_list
         | 
| 10 | 
            +
                    self.key = key
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def __call__(self, preds, label=None, *args, **kwargs):
         | 
| 13 | 
            +
                    if self.key is not None:
         | 
| 14 | 
            +
                        preds = preds[self.key]
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    label_list = self.label_list
         | 
| 17 | 
            +
                    if label_list is None:
         | 
| 18 | 
            +
                        label_list = {idx: idx for idx in range(preds.shape[-1])}
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    if isinstance(preds, paddle.Tensor):
         | 
| 21 | 
            +
                        preds = preds.numpy()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    pred_idxs = preds.argmax(axis=1)
         | 
| 24 | 
            +
                    decode_out = [
         | 
| 25 | 
            +
                        (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
         | 
| 26 | 
            +
                    ]
         | 
| 27 | 
            +
                    if label is None:
         | 
| 28 | 
            +
                        return decode_out
         | 
| 29 | 
            +
                    label = [(label_list[idx], 1.0) for idx in label]
         | 
| 30 | 
            +
                    return decode_out, label
         | 
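ClsPostProcess takes the argmax over the class scores and maps each index through label_list (or the index itself when label_list is None). A toy sketch with made-up scores; the ["0", "180"] labels are illustrative.

    # Sketch only: decode a dummy two-class direction prediction.
    import numpy as np
    from ocr.postprocess.cls_postprocess import ClsPostProcess

    cls_pp = ClsPostProcess(label_list=["0", "180"])             # illustrative labels
    preds = np.array([[0.9, 0.1], [0.2, 0.8]], dtype=np.float32)
    print(cls_pp(preds))                                         # [('0', 0.9), ('180', 0.8)]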
    	
        ocr/postprocess/db_postprocess.py
    ADDED
    
    | @@ -0,0 +1,207 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import paddle
         | 
| 6 | 
            +
            import pyclipper
         | 
| 7 | 
            +
            from shapely.geometry import Polygon
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class DBPostProcess(object):
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                The post process for Differentiable Binarization (DB).
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def __init__(
         | 
| 16 | 
            +
                    self,
         | 
| 17 | 
            +
                    thresh=0.3,
         | 
| 18 | 
            +
                    box_thresh=0.7,
         | 
| 19 | 
            +
                    max_candidates=1000,
         | 
| 20 | 
            +
                    unclip_ratio=2.0,
         | 
| 21 | 
            +
                    use_dilation=False,
         | 
| 22 | 
            +
                    score_mode="fast",
         | 
| 23 | 
            +
                    **kwargs
         | 
| 24 | 
            +
                ):
         | 
| 25 | 
            +
                    self.thresh = thresh
         | 
| 26 | 
            +
                    self.box_thresh = box_thresh
         | 
| 27 | 
            +
                    self.max_candidates = max_candidates
         | 
| 28 | 
            +
                    self.unclip_ratio = unclip_ratio
         | 
| 29 | 
            +
                    self.min_size = 3
         | 
| 30 | 
            +
                    self.score_mode = score_mode
         | 
| 31 | 
            +
                    assert score_mode in [
         | 
| 32 | 
            +
                        "slow",
         | 
| 33 | 
            +
                        "fast",
         | 
| 34 | 
            +
                    ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    _bitmap: single map with shape (1, H, W),
         | 
| 41 | 
            +
                            whose values are binarized as {0, 1}
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    bitmap = _bitmap
         | 
| 45 | 
            +
                    height, width = bitmap.shape
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    outs = cv2.findContours(
         | 
| 48 | 
            +
                        (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
         | 
| 49 | 
            +
                    )
         | 
| 50 | 
            +
                    if len(outs) == 3:
         | 
| 51 | 
            +
                        img, contours, _ = outs[0], outs[1], outs[2]
         | 
| 52 | 
            +
                    elif len(outs) == 2:
         | 
| 53 | 
            +
                        contours, _ = outs[0], outs[1]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    num_contours = min(len(contours), self.max_candidates)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    boxes = []
         | 
| 58 | 
            +
                    scores = []
         | 
| 59 | 
            +
                    for index in range(num_contours):
         | 
| 60 | 
            +
                        contour = contours[index]
         | 
| 61 | 
            +
                        points, sside = self.get_mini_boxes(contour)
         | 
| 62 | 
            +
                        if sside < self.min_size:
         | 
| 63 | 
            +
                            continue
         | 
| 64 | 
            +
                        points = np.array(points)
         | 
| 65 | 
            +
                        if self.score_mode == "fast":
         | 
| 66 | 
            +
                            score = self.box_score_fast(pred, points.reshape(-1, 2))
         | 
| 67 | 
            +
                        else:
         | 
| 68 | 
            +
                            score = self.box_score_slow(pred, contour)
         | 
| 69 | 
            +
                        if self.box_thresh > score:
         | 
| 70 | 
            +
                            continue
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                        box = self.unclip(points).reshape(-1, 1, 2)
         | 
| 73 | 
            +
                        box, sside = self.get_mini_boxes(box)
         | 
| 74 | 
            +
                        if sside < self.min_size + 2:
         | 
| 75 | 
            +
                            continue
         | 
| 76 | 
            +
                        box = np.array(box)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                        box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
         | 
| 79 | 
            +
                        box[:, 1] = np.clip(
         | 
| 80 | 
            +
                            np.round(box[:, 1] / height * dest_height), 0, dest_height
         | 
| 81 | 
            +
                        )
         | 
| 82 | 
            +
                        boxes.append(box.astype(np.int16))
         | 
| 83 | 
            +
                        scores.append(score)
         | 
| 84 | 
            +
                    return np.array(boxes, dtype=np.int16), scores
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def unclip(self, box):
         | 
| 87 | 
            +
                    unclip_ratio = self.unclip_ratio
         | 
| 88 | 
            +
                    poly = Polygon(box)
         | 
| 89 | 
            +
                    distance = poly.area * unclip_ratio / poly.length
         | 
| 90 | 
            +
                    offset = pyclipper.PyclipperOffset()
         | 
| 91 | 
            +
                    offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         | 
| 92 | 
            +
                    expanded = np.array(offset.Execute(distance))
         | 
| 93 | 
            +
                    return expanded
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def get_mini_boxes(self, contour):
         | 
| 96 | 
            +
                    bounding_box = cv2.minAreaRect(contour)
         | 
| 97 | 
            +
                    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    index_1, index_2, index_3, index_4 = 0, 1, 2, 3
         | 
| 100 | 
            +
                    if points[1][1] > points[0][1]:
         | 
| 101 | 
            +
                        index_1 = 0
         | 
| 102 | 
            +
                        index_4 = 1
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        index_1 = 1
         | 
| 105 | 
            +
                        index_4 = 0
         | 
| 106 | 
            +
                    if points[3][1] > points[2][1]:
         | 
| 107 | 
            +
                        index_2 = 2
         | 
| 108 | 
            +
                        index_3 = 3
         | 
| 109 | 
            +
                    else:
         | 
| 110 | 
            +
                        index_2 = 3
         | 
| 111 | 
            +
                        index_3 = 2
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    box = [points[index_1], points[index_2], points[index_3], points[index_4]]
         | 
| 114 | 
            +
                    return box, min(bounding_box[1])
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def box_score_fast(self, bitmap, _box):
         | 
| 117 | 
            +
                    """
         | 
| 118 | 
            +
                box_score_fast: use the mean score inside the bounding box as the box score
         | 
| 119 | 
            +
                    """
         | 
| 120 | 
            +
                    h, w = bitmap.shape[:2]
         | 
| 121 | 
            +
                    box = _box.copy()
         | 
| 122 | 
            +
                    xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
         | 
| 123 | 
            +
                    xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
         | 
| 124 | 
            +
                    ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
         | 
| 125 | 
            +
                    ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         | 
| 128 | 
            +
                    box[:, 0] = box[:, 0] - xmin
         | 
| 129 | 
            +
                    box[:, 1] = box[:, 1] - ymin
         | 
| 130 | 
            +
                    cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
         | 
| 131 | 
            +
                    return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def box_score_slow(self, bitmap, contour):
         | 
| 134 | 
            +
                    """
         | 
| 135 | 
            +
                box_score_slow: use the mean score inside the polygon as the box score
         | 
| 136 | 
            +
                    """
         | 
| 137 | 
            +
                    h, w = bitmap.shape[:2]
         | 
| 138 | 
            +
                    contour = contour.copy()
         | 
| 139 | 
            +
                    contour = np.reshape(contour, (-1, 2))
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
         | 
| 142 | 
            +
                    xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
         | 
| 143 | 
            +
                    ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
         | 
| 144 | 
            +
                    ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    contour[:, 0] = contour[:, 0] - xmin
         | 
| 149 | 
            +
                    contour[:, 1] = contour[:, 1] - ymin
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
         | 
| 152 | 
            +
                    return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                def __call__(self, outs_dict, shape_list):
         | 
| 155 | 
            +
                    pred = outs_dict["maps"]
         | 
| 156 | 
            +
                    if isinstance(pred, paddle.Tensor):
         | 
| 157 | 
            +
                        pred = pred.numpy()
         | 
| 158 | 
            +
                    pred = pred[:, 0, :, :]
         | 
| 159 | 
            +
                    segmentation = pred > self.thresh
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    boxes_batch = []
         | 
| 162 | 
            +
                    for batch_index in range(pred.shape[0]):
         | 
| 163 | 
            +
                        src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
         | 
| 164 | 
            +
                        if self.dilation_kernel is not None:
         | 
| 165 | 
            +
                            mask = cv2.dilate(
         | 
| 166 | 
            +
                                np.array(segmentation[batch_index]).astype(np.uint8),
         | 
| 167 | 
            +
                                self.dilation_kernel,
         | 
| 168 | 
            +
                            )
         | 
| 169 | 
            +
                        else:
         | 
| 170 | 
            +
                            mask = segmentation[batch_index]
         | 
| 171 | 
            +
                        boxes, scores = self.boxes_from_bitmap(
         | 
| 172 | 
            +
                            pred[batch_index], mask, src_w, src_h
         | 
| 173 | 
            +
                        )
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                        boxes_batch.append({"points": boxes})
         | 
| 176 | 
            +
                    return boxes_batch
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            class DistillationDBPostProcess(object):
         | 
| 180 | 
            +
                def __init__(
         | 
| 181 | 
            +
                    self,
         | 
| 182 | 
            +
                    model_name=["student"],
         | 
| 183 | 
            +
                    key=None,
         | 
| 184 | 
            +
                    thresh=0.3,
         | 
| 185 | 
            +
                    box_thresh=0.6,
         | 
| 186 | 
            +
                    max_candidates=1000,
         | 
| 187 | 
            +
                    unclip_ratio=1.5,
         | 
| 188 | 
            +
                    use_dilation=False,
         | 
| 189 | 
            +
                    score_mode="fast",
         | 
| 190 | 
            +
                    **kwargs
         | 
| 191 | 
            +
                ):
         | 
| 192 | 
            +
                    self.model_name = model_name
         | 
| 193 | 
            +
                    self.key = key
         | 
| 194 | 
            +
                    self.post_process = DBPostProcess(
         | 
| 195 | 
            +
                        thresh=thresh,
         | 
| 196 | 
            +
                        box_thresh=box_thresh,
         | 
| 197 | 
            +
                        max_candidates=max_candidates,
         | 
| 198 | 
            +
                        unclip_ratio=unclip_ratio,
         | 
| 199 | 
            +
                        use_dilation=use_dilation,
         | 
| 200 | 
            +
                        score_mode=score_mode,
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def __call__(self, predicts, shape_list):
         | 
| 204 | 
            +
                    results = {}
         | 
| 205 | 
            +
                    for k in self.model_name:
         | 
| 206 | 
            +
                        results[k] = self.post_process(predicts[k], shape_list=shape_list)
         | 
| 207 | 
            +
                    return results
         | 
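DBPostProcess consumes the raw probability map under the "maps" key plus a per-image shape_list of [src_h, src_w, ratio_h, ratio_w], and returns one {"points": boxes} dict per image. A minimal sketch on a synthetic map; the shapes and thresholds are illustrative.

    # Sketch only: run DBPostProcess on a synthetic probability map.
    import numpy as np
    from ocr.postprocess.db_postprocess import DBPostProcess

    pp = DBPostProcess(thresh=0.3, box_thresh=0.6, unclip_ratio=1.5)
    pred = np.zeros((1, 1, 160, 160), dtype=np.float32)     # (batch, 1, H, W)
    pred[0, 0, 40:80, 30:120] = 0.9                         # one fake text region
    shape_list = np.array([[160.0, 160.0, 1.0, 1.0]])       # src_h, src_w, ratio_h, ratio_w
    boxes = pp({"maps": pred}, shape_list)[0]["points"]
    print(boxes.shape)                                      # (1, 4, 2): one quadrilateral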
    	
        ocr/postprocess/east_postprocess.py
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import paddle
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .locality_aware_nms import nms_locality
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class EASTPostProcess(object):
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                The post process for EAST.
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                    self.score_thresh = score_thresh
         | 
| 18 | 
            +
                    self.cover_thresh = cover_thresh
         | 
| 19 | 
            +
                    self.nms_thresh = nms_thresh
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def restore_rectangle_quad(self, origin, geometry):
         | 
| 22 | 
            +
                    """
         | 
| 23 | 
            +
                    Restore rectangle from quadrangle.
         | 
| 24 | 
            +
                    """
         | 
| 25 | 
            +
                    # quad
         | 
| 26 | 
            +
                    origin_concat = np.concatenate(
         | 
| 27 | 
            +
                        (origin, origin, origin, origin), axis=1
         | 
| 28 | 
            +
                    )  # (n, 8)
         | 
| 29 | 
            +
                    pred_quads = origin_concat - geometry
         | 
| 30 | 
            +
                    pred_quads = pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)
         | 
| 31 | 
            +
                    return pred_quads
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def detect(
         | 
| 34 | 
            +
                    self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
         | 
| 35 | 
            +
                ):
         | 
| 36 | 
            +
                    """
         | 
| 37 | 
            +
                    restore text boxes from score map and geo map
         | 
| 38 | 
            +
                    """
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    score_map = score_map[0]
         | 
| 41 | 
            +
                    geo_map = np.swapaxes(geo_map, 1, 0)
         | 
| 42 | 
            +
                    geo_map = np.swapaxes(geo_map, 1, 2)
         | 
| 43 | 
            +
                    # filter the score map
         | 
| 44 | 
            +
                    xy_text = np.argwhere(score_map > score_thresh)
         | 
| 45 | 
            +
                    if len(xy_text) == 0:
         | 
| 46 | 
            +
                        return []
         | 
| 47 | 
            +
                    # sort the text boxes via the y axis
         | 
| 48 | 
            +
                    xy_text = xy_text[np.argsort(xy_text[:, 0])]
         | 
| 49 | 
            +
                    # restore quad proposals
         | 
| 50 | 
            +
                    text_box_restored = self.restore_rectangle_quad(
         | 
| 51 | 
            +
                        xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
         | 
| 52 | 
            +
                    )
         | 
| 53 | 
            +
                    boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
         | 
| 54 | 
            +
                    boxes[:, :8] = text_box_restored.reshape((-1, 8))
         | 
| 55 | 
            +
                    boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    try:
         | 
| 58 | 
            +
                        import lanms
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                        boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
         | 
| 61 | 
            +
                    except:
         | 
| 62 | 
            +
                        print(
         | 
| 63 | 
            +
                            "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
         | 
| 64 | 
            +
                        )
         | 
| 65 | 
            +
                        boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
         | 
| 66 | 
            +
                    if boxes.shape[0] == 0:
         | 
| 67 | 
            +
                        return []
         | 
| 68 | 
            +
                    # Here we filter some low score boxes by the average score map,
         | 
| 69 | 
            +
                    #   this is different from the original paper.
         | 
| 70 | 
            +
                    for i, box in enumerate(boxes):
         | 
| 71 | 
            +
                        mask = np.zeros_like(score_map, dtype=np.uint8)
         | 
| 72 | 
            +
                        cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
         | 
| 73 | 
            +
                        boxes[i, 8] = cv2.mean(score_map, mask)[0]
         | 
| 74 | 
            +
                    boxes = boxes[boxes[:, 8] > cover_thresh]
         | 
| 75 | 
            +
                    return boxes
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def sort_poly(self, p):
         | 
| 78 | 
            +
                    """
         | 
| 79 | 
            +
                    Sort polygons.
         | 
| 80 | 
            +
                    """
         | 
| 81 | 
            +
                    min_axis = np.argmin(np.sum(p, axis=1))
         | 
| 82 | 
            +
                    p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
         | 
| 83 | 
            +
                    if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
         | 
| 84 | 
            +
                        return p
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        return p[[0, 3, 2, 1]]
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __call__(self, outs_dict, shape_list):
         | 
| 89 | 
            +
                    score_list = outs_dict["f_score"]
         | 
| 90 | 
            +
                    geo_list = outs_dict["f_geo"]
         | 
| 91 | 
            +
                    if isinstance(score_list, paddle.Tensor):
         | 
| 92 | 
            +
                        score_list = score_list.numpy()
         | 
| 93 | 
            +
                        geo_list = geo_list.numpy()
         | 
| 94 | 
            +
                    img_num = len(shape_list)
         | 
| 95 | 
            +
                    dt_boxes_list = []
         | 
| 96 | 
            +
                    for ino in range(img_num):
         | 
| 97 | 
            +
                        score = score_list[ino]
         | 
| 98 | 
            +
                        geo = geo_list[ino]
         | 
| 99 | 
            +
                        boxes = self.detect(
         | 
| 100 | 
            +
                            score_map=score,
         | 
| 101 | 
            +
                            geo_map=geo,
         | 
| 102 | 
            +
                            score_thresh=self.score_thresh,
         | 
| 103 | 
            +
                            cover_thresh=self.cover_thresh,
         | 
| 104 | 
            +
                            nms_thresh=self.nms_thresh,
         | 
| 105 | 
            +
                        )
         | 
| 106 | 
            +
                        boxes_norm = []
         | 
| 107 | 
            +
                        if len(boxes) > 0:
         | 
| 108 | 
            +
                            h, w = score.shape[1:]
         | 
| 109 | 
            +
                            src_h, src_w, ratio_h, ratio_w = shape_list[ino]
         | 
| 110 | 
            +
                            boxes = boxes[:, :8].reshape((-1, 4, 2))
         | 
| 111 | 
            +
                            boxes[:, :, 0] /= ratio_w
         | 
| 112 | 
            +
                            boxes[:, :, 1] /= ratio_h
         | 
| 113 | 
            +
                            for i_box, box in enumerate(boxes):
         | 
| 114 | 
            +
                                box = self.sort_poly(box.astype(np.int32))
         | 
| 115 | 
            +
                                if (
         | 
| 116 | 
            +
                                    np.linalg.norm(box[0] - box[1]) < 5
         | 
| 117 | 
            +
                                    or np.linalg.norm(box[3] - box[0]) < 5
         | 
| 118 | 
            +
                                ):
         | 
| 119 | 
            +
                                    continue
         | 
| 120 | 
            +
                                boxes_norm.append(box)
         | 
| 121 | 
            +
                        dt_boxes_list.append({"points": np.array(boxes_norm)})
         | 
| 122 | 
            +
                    return dt_boxes_list
         | 
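EASTPostProcess follows the same calling convention with the f_score and f_geo heads; when lanms is not installed it falls back to the pure-Python nms_locality (see the print above). A minimal interface sketch on all-zero heads, which simply yields no boxes; the shapes are illustrative, with the heads at 1/4 of the input resolution to match the * 4 upscaling in detect.

    # Sketch only: call EASTPostProcess on empty heads; expect no detections.
    import numpy as np
    from ocr.postprocess.east_postprocess import EASTPostProcess

    pp = EASTPostProcess(score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2)
    f_score = np.zeros((1, 1, 128, 128), dtype=np.float32)   # (batch, 1, H/4, W/4)
    f_geo = np.zeros((1, 8, 128, 128), dtype=np.float32)     # 8 quad offsets per location
    shape_list = [[512, 512, 1.0, 1.0]]                      # src_h, src_w, ratio_h, ratio_w
    out = pp({"f_score": f_score, "f_geo": f_geo}, shape_list)
    print(out[0]["points"])                                  # empty array: nothing above score_thresh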
    	
        ocr/postprocess/extract_textpoint_fast.py
    ADDED
    
    | @@ -0,0 +1,464 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from itertools import groupby
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import cv2
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from skimage.morphology._skeletonize import thin
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def get_dict(character_dict_path):
         | 
| 11 | 
            +
                character_str = ""
         | 
| 12 | 
            +
                with open(character_dict_path, "rb") as fin:
         | 
| 13 | 
            +
                    lines = fin.readlines()
         | 
| 14 | 
            +
                    for line in lines:
         | 
| 15 | 
            +
                        line = line.decode("utf-8").strip("\n").strip("\r\n")
         | 
| 16 | 
            +
                        character_str += line
         | 
| 17 | 
            +
                    dict_character = list(character_str)
         | 
| 18 | 
            +
                return dict_character
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def softmax(logits):
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
                logits: N x d
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                max_value = np.max(logits, axis=1, keepdims=True)
         | 
| 26 | 
            +
                exp = np.exp(logits - max_value)
         | 
| 27 | 
            +
                exp_sum = np.sum(exp, axis=1, keepdims=True)
         | 
| 28 | 
            +
                dist = exp / exp_sum
         | 
| 29 | 
            +
                return dist
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def get_keep_pos_idxs(labels, remove_blank=None):
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                Remove duplicate and get pos idxs of keep items.
         | 
| 35 | 
            +
                The value of remove_blank should be either None or the blank index (e.g. 95).
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
                duplicate_len_list = []
         | 
| 38 | 
            +
                keep_pos_idx_list = []
         | 
| 39 | 
            +
                keep_char_idx_list = []
         | 
| 40 | 
            +
                for k, v_ in groupby(labels):
         | 
| 41 | 
            +
                    current_len = len(list(v_))
         | 
| 42 | 
            +
                    if k != remove_blank:
         | 
| 43 | 
            +
                        current_idx = int(sum(duplicate_len_list) + current_len // 2)
         | 
| 44 | 
            +
                        keep_pos_idx_list.append(current_idx)
         | 
| 45 | 
            +
                        keep_char_idx_list.append(k)
         | 
| 46 | 
            +
                    duplicate_len_list.append(current_len)
         | 
| 47 | 
            +
                return keep_char_idx_list, keep_pos_idx_list
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def remove_blank(labels, blank=0):
         | 
| 51 | 
            +
                new_labels = [x for x in labels if x != blank]
         | 
| 52 | 
            +
                return new_labels
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def insert_blank(labels, blank=0):
         | 
| 56 | 
            +
                new_labels = [blank]
         | 
| 57 | 
            +
                for l in labels:
         | 
| 58 | 
            +
                    new_labels += [l, blank]
         | 
| 59 | 
            +
                return new_labels
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
         | 
| 63 | 
            +
                """
         | 
| 64 | 
            +
                CTC greedy (best path) decoder.
         | 
| 65 | 
            +
                """
         | 
| 66 | 
            +
                raw_str = np.argmax(np.array(probs_seq), axis=1)
         | 
| 67 | 
            +
                remove_blank_in_pos = None if keep_blank_in_idxs else blank
         | 
| 68 | 
            +
                dedup_str, keep_idx_list = get_keep_pos_idxs(
         | 
| 69 | 
            +
                    raw_str, remove_blank=remove_blank_in_pos
         | 
| 70 | 
            +
                )
         | 
| 71 | 
            +
                dst_str = remove_blank(dedup_str, blank=blank)
         | 
| 72 | 
            +
                return dst_str, keep_idx_list
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
         | 
| 76 | 
            +
                _, _, C = logits_map.shape
         | 
| 77 | 
            +
                ys, xs = zip(*gather_info)
         | 
| 78 | 
            +
                logits_seq = logits_map[list(ys), list(xs)]
         | 
| 79 | 
            +
                probs_seq = logits_seq
         | 
| 80 | 
            +
                labels = np.argmax(probs_seq, axis=1)
         | 
| 81 | 
            +
                dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
         | 
| 82 | 
            +
                detal = len(gather_info) // (pts_num - 1)
         | 
| 83 | 
            +
                keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
         | 
| 84 | 
            +
                keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
         | 
| 85 | 
            +
                return dst_str, keep_gather_list
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6):
         | 
| 89 | 
            +
                """
         | 
| 90 | 
            +
                CTC decoder using multiple processes.
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                decoder_str = []
         | 
| 93 | 
            +
                decoder_xys = []
         | 
| 94 | 
            +
                for gather_info in gather_info_list:
         | 
| 95 | 
            +
                    if len(gather_info) < pts_num:
         | 
| 96 | 
            +
                        continue
         | 
| 97 | 
            +
                    dst_str, xys_list = instance_ctc_greedy_decoder(
         | 
| 98 | 
            +
                        gather_info, logits_map, pts_num=pts_num
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
         | 
| 101 | 
            +
                    if len(dst_str_readable) < 2:
         | 
| 102 | 
            +
                        continue
         | 
| 103 | 
            +
                    decoder_str.append(dst_str_readable)
         | 
| 104 | 
            +
                    decoder_xys.append(xys_list)
         | 
| 105 | 
            +
                return decoder_str, decoder_xys
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            def sort_with_direction(pos_list, f_direction):
         | 
| 109 | 
            +
                """
         | 
| 110 | 
            +
                f_direction: h x w x 2
         | 
| 111 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 112 | 
            +
                """
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def sort_part_with_direction(pos_list, point_direction):
         | 
| 115 | 
            +
                    pos_list = np.array(pos_list).reshape(-1, 2)
         | 
| 116 | 
            +
                    point_direction = np.array(point_direction).reshape(-1, 2)
         | 
| 117 | 
            +
                    average_direction = np.mean(point_direction, axis=0, keepdims=True)
         | 
| 118 | 
            +
                    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
         | 
| 119 | 
            +
                    sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
         | 
| 120 | 
            +
                    sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
         | 
| 121 | 
            +
                    return sorted_list, sorted_direction
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                pos_list = np.array(pos_list).reshape(-1, 2)
         | 
| 124 | 
            +
                point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
         | 
| 125 | 
            +
                point_direction = point_direction[:, ::-1]  # x, y -> y, x
         | 
| 126 | 
            +
                sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                point_num = len(sorted_point)
         | 
| 129 | 
            +
                if point_num >= 16:
         | 
| 130 | 
            +
                    middle_num = point_num // 2
         | 
| 131 | 
            +
                    first_part_point = sorted_point[:middle_num]
         | 
| 132 | 
            +
                    first_point_direction = sorted_direction[:middle_num]
         | 
| 133 | 
            +
                    sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
         | 
| 134 | 
            +
                        first_part_point, first_point_direction
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    last_part_point = sorted_point[middle_num:]
         | 
| 138 | 
            +
                    last_point_direction = sorted_direction[middle_num:]
         | 
| 139 | 
            +
                    sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
         | 
| 140 | 
            +
                        last_part_point, last_point_direction
         | 
| 141 | 
            +
                    )
         | 
| 142 | 
            +
                    sorted_point = sorted_fist_part_point + sorted_last_part_point
         | 
| 143 | 
            +
                    sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                return sorted_point, np.array(sorted_direction)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def add_id(pos_list, image_id=0):
         | 
| 149 | 
            +
                """
         | 
| 150 | 
            +
                Add id for gather feature, for inference.
         | 
| 151 | 
            +
                """
         | 
| 152 | 
            +
                new_list = []
         | 
| 153 | 
            +
                for item in pos_list:
         | 
| 154 | 
            +
                    new_list.append((image_id, item[0], item[1]))
         | 
| 155 | 
            +
                return new_list
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def sort_and_expand_with_direction(pos_list, f_direction):
         | 
| 159 | 
            +
                """
         | 
| 160 | 
            +
                f_direction: h x w x 2
         | 
| 161 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 162 | 
            +
                """
         | 
| 163 | 
            +
                h, w, _ = f_direction.shape
         | 
| 164 | 
            +
                sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                point_num = len(sorted_list)
         | 
| 167 | 
            +
                sub_direction_len = max(point_num // 3, 2)
         | 
| 168 | 
            +
                left_direction = point_direction[:sub_direction_len, :]
         | 
| 169 | 
            +
                right_dirction = point_direction[point_num - sub_direction_len :, :]
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
         | 
| 172 | 
            +
                left_average_len = np.linalg.norm(left_average_direction)
         | 
| 173 | 
            +
                left_start = np.array(sorted_list[0])
         | 
| 174 | 
            +
                left_step = left_average_direction / (left_average_len + 1e-6)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
         | 
| 177 | 
            +
                right_average_len = np.linalg.norm(right_average_direction)
         | 
| 178 | 
            +
                right_step = right_average_direction / (right_average_len + 1e-6)
         | 
| 179 | 
            +
                right_start = np.array(sorted_list[-1])
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
         | 
| 182 | 
            +
                left_list = []
         | 
| 183 | 
            +
                right_list = []
         | 
| 184 | 
            +
                for i in range(append_num):
         | 
| 185 | 
            +
                    ly, lx = (
         | 
| 186 | 
            +
                        np.round(left_start + left_step * (i + 1))
         | 
| 187 | 
            +
                        .flatten()
         | 
| 188 | 
            +
                        .astype("int32")
         | 
| 189 | 
            +
                        .tolist()
         | 
| 190 | 
            +
                    )
         | 
| 191 | 
            +
                    if ly < h and lx < w and (ly, lx) not in left_list:
         | 
| 192 | 
            +
                        left_list.append((ly, lx))
         | 
| 193 | 
            +
                    ry, rx = (
         | 
| 194 | 
            +
                        np.round(right_start + right_step * (i + 1))
         | 
| 195 | 
            +
                        .flatten()
         | 
| 196 | 
            +
                        .astype("int32")
         | 
| 197 | 
            +
                        .tolist()
         | 
| 198 | 
            +
                    )
         | 
| 199 | 
            +
                    if ry < h and rx < w and (ry, rx) not in right_list:
         | 
| 200 | 
            +
                        right_list.append((ry, rx))
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                all_list = left_list[::-1] + sorted_list + right_list
         | 
| 203 | 
            +
                return all_list
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
         | 
| 207 | 
            +
                """
         | 
| 208 | 
            +
                f_direction: h x w x 2
         | 
| 209 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 210 | 
            +
                binary_tcl_map: h x w
         | 
| 211 | 
            +
                """
         | 
| 212 | 
            +
                h, w, _ = f_direction.shape
         | 
| 213 | 
            +
                sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                point_num = len(sorted_list)
         | 
| 216 | 
            +
                sub_direction_len = max(point_num // 3, 2)
         | 
| 217 | 
            +
                left_direction = point_direction[:sub_direction_len, :]
         | 
| 218 | 
            +
                right_dirction = point_direction[point_num - sub_direction_len :, :]
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
         | 
| 221 | 
            +
                left_average_len = np.linalg.norm(left_average_direction)
         | 
| 222 | 
            +
                left_start = np.array(sorted_list[0])
         | 
| 223 | 
            +
                left_step = left_average_direction / (left_average_len + 1e-6)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
         | 
| 226 | 
            +
                right_average_len = np.linalg.norm(right_average_direction)
         | 
| 227 | 
            +
                right_step = right_average_direction / (right_average_len + 1e-6)
         | 
| 228 | 
            +
                right_start = np.array(sorted_list[-1])
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
         | 
| 231 | 
            +
                max_append_num = 2 * append_num
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                left_list = []
         | 
| 234 | 
            +
                right_list = []
         | 
| 235 | 
            +
                for i in range(max_append_num):
         | 
| 236 | 
            +
                    ly, lx = (
         | 
| 237 | 
            +
                        np.round(left_start + left_step * (i + 1))
         | 
| 238 | 
            +
                        .flatten()
         | 
| 239 | 
            +
                        .astype("int32")
         | 
| 240 | 
            +
                        .tolist()
         | 
| 241 | 
            +
                    )
         | 
| 242 | 
            +
                    if ly < h and lx < w and (ly, lx) not in left_list:
         | 
| 243 | 
            +
                        if binary_tcl_map[ly, lx] > 0.5:
         | 
| 244 | 
            +
                            left_list.append((ly, lx))
         | 
| 245 | 
            +
                        else:
         | 
| 246 | 
            +
                            break
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                for i in range(max_append_num):
         | 
| 249 | 
            +
                    ry, rx = (
         | 
| 250 | 
            +
                        np.round(right_start + right_step * (i + 1))
         | 
| 251 | 
            +
                        .flatten()
         | 
| 252 | 
            +
                        .astype("int32")
         | 
| 253 | 
            +
                        .tolist()
         | 
| 254 | 
            +
                    )
         | 
| 255 | 
            +
                    if ry < h and rx < w and (ry, rx) not in right_list:
         | 
| 256 | 
            +
                        if binary_tcl_map[ry, rx] > 0.5:
         | 
| 257 | 
            +
                            right_list.append((ry, rx))
         | 
| 258 | 
            +
                        else:
         | 
| 259 | 
            +
                            break
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                all_list = left_list[::-1] + sorted_list + right_list
         | 
| 262 | 
            +
                return all_list
         | 
| 263 | 
            +
             | 
| 264 | 
            +
             | 
| 265 | 
            +
            def point_pair2poly(point_pair_list):
         | 
| 266 | 
            +
                """
         | 
| 267 | 
            +
                Transfer vertical point_pairs into poly point in clockwise.
         | 
| 268 | 
            +
                """
         | 
| 269 | 
            +
                point_num = len(point_pair_list) * 2
         | 
| 270 | 
            +
                point_list = [0] * point_num
         | 
| 271 | 
            +
                for idx, point_pair in enumerate(point_pair_list):
         | 
| 272 | 
            +
                    point_list[idx] = point_pair[0]
         | 
| 273 | 
            +
                    point_list[point_num - 1 - idx] = point_pair[1]
         | 
| 274 | 
            +
                return np.array(point_list).reshape(-1, 2)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
             | 
| 277 | 
            +
            def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
         | 
| 278 | 
            +
                ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
         | 
| 279 | 
            +
                p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
         | 
| 280 | 
            +
                p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
         | 
| 281 | 
            +
                return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
         | 
| 282 | 
            +
             | 
| 283 | 
            +
             | 
| 284 | 
            +
            def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
         | 
| 285 | 
            +
                """
         | 
| 286 | 
            +
                expand poly along width.
         | 
| 287 | 
            +
                """
         | 
| 288 | 
            +
                point_num = poly.shape[0]
         | 
| 289 | 
            +
                left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
         | 
| 290 | 
            +
                left_ratio = (
         | 
| 291 | 
            +
                    -shrink_ratio_of_width
         | 
| 292 | 
            +
                    * np.linalg.norm(left_quad[0] - left_quad[3])
         | 
| 293 | 
            +
                    / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
         | 
| 294 | 
            +
                )
         | 
| 295 | 
            +
                left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
         | 
| 296 | 
            +
                right_quad = np.array(
         | 
| 297 | 
            +
                    [
         | 
| 298 | 
            +
                        poly[point_num // 2 - 2],
         | 
| 299 | 
            +
                        poly[point_num // 2 - 1],
         | 
| 300 | 
            +
                        poly[point_num // 2],
         | 
| 301 | 
            +
                        poly[point_num // 2 + 1],
         | 
| 302 | 
            +
                    ],
         | 
| 303 | 
            +
                    dtype=np.float32,
         | 
| 304 | 
            +
                )
         | 
| 305 | 
            +
                right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
         | 
| 306 | 
            +
                    right_quad[0] - right_quad[3]
         | 
| 307 | 
            +
                ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
         | 
| 308 | 
            +
                right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
         | 
| 309 | 
            +
                poly[0] = left_quad_expand[0]
         | 
| 310 | 
            +
                poly[-1] = left_quad_expand[-1]
         | 
| 311 | 
            +
                poly[point_num // 2 - 1] = right_quad_expand[1]
         | 
| 312 | 
            +
                poly[point_num // 2] = right_quad_expand[2]
         | 
| 313 | 
            +
                return poly
         | 
| 314 | 
            +
             | 
| 315 | 
            +
             | 
| 316 | 
            +
            def restore_poly(
         | 
| 317 | 
            +
                instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set
         | 
| 318 | 
            +
            ):
         | 
| 319 | 
            +
                poly_list = []
         | 
| 320 | 
            +
                keep_str_list = []
         | 
| 321 | 
            +
                for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
         | 
| 322 | 
            +
                    if len(keep_str) < 2:
         | 
| 323 | 
            +
                        print("--> too short, {}".format(keep_str))
         | 
| 324 | 
            +
                        continue
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    offset_expand = 1.0
         | 
| 327 | 
            +
                    if valid_set == "totaltext":
         | 
| 328 | 
            +
                        offset_expand = 1.2
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    point_pair_list = []
         | 
| 331 | 
            +
                    for y, x in yx_center_line:
         | 
| 332 | 
            +
                        offset = p_border[:, y, x].reshape(2, 2) * offset_expand
         | 
| 333 | 
            +
                        ori_yx = np.array([y, x], dtype=np.float32)
         | 
| 334 | 
            +
                        point_pair = (
         | 
| 335 | 
            +
                            (ori_yx + offset)[:, ::-1]
         | 
| 336 | 
            +
                            * 4.0
         | 
| 337 | 
            +
                            / np.array([ratio_w, ratio_h]).reshape(-1, 2)
         | 
| 338 | 
            +
                        )
         | 
| 339 | 
            +
                        point_pair_list.append(point_pair)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    detected_poly = point_pair2poly(point_pair_list)
         | 
| 342 | 
            +
                    detected_poly = expand_poly_along_width(
         | 
| 343 | 
            +
                        detected_poly, shrink_ratio_of_width=0.2
         | 
| 344 | 
            +
                    )
         | 
| 345 | 
            +
                    detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
         | 
| 346 | 
            +
                    detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    keep_str_list.append(keep_str)
         | 
| 349 | 
            +
                    if valid_set == "partvgg":
         | 
| 350 | 
            +
                        middle_point = len(detected_poly) // 2
         | 
| 351 | 
            +
                        detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :]
         | 
| 352 | 
            +
                        poly_list.append(detected_poly)
         | 
| 353 | 
            +
                    elif valid_set == "totaltext":
         | 
| 354 | 
            +
                        poly_list.append(detected_poly)
         | 
| 355 | 
            +
                    else:
         | 
| 356 | 
            +
                        print("--> Not supported format.")
         | 
| 357 | 
            +
                        exit(-1)
         | 
| 358 | 
            +
                return poly_list, keep_str_list
         | 
| 359 | 
            +
             | 
| 360 | 
            +
             | 
| 361 | 
            +
            def generate_pivot_list_fast(
         | 
| 362 | 
            +
                p_score, p_char_maps, f_direction, Lexicon_Table, score_thresh=0.5
         | 
| 363 | 
            +
            ):
         | 
| 364 | 
            +
                """
         | 
| 365 | 
            +
                return center point and end point of TCL instance; filter with the char maps;
         | 
| 366 | 
            +
                """
         | 
| 367 | 
            +
                p_score = p_score[0]
         | 
| 368 | 
            +
                f_direction = f_direction.transpose(1, 2, 0)
         | 
| 369 | 
            +
                p_tcl_map = (p_score > score_thresh) * 1.0
         | 
| 370 | 
            +
                skeleton_map = thin(p_tcl_map.astype(np.uint8))
         | 
| 371 | 
            +
                instance_count, instance_label_map = cv2.connectedComponents(
         | 
| 372 | 
            +
                    skeleton_map.astype(np.uint8), connectivity=8
         | 
| 373 | 
            +
                )
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                # get TCL Instance
         | 
| 376 | 
            +
                all_pos_yxs = []
         | 
| 377 | 
            +
                if instance_count > 0:
         | 
| 378 | 
            +
                    for instance_id in range(1, instance_count):
         | 
| 379 | 
            +
                        pos_list = []
         | 
| 380 | 
            +
                        ys, xs = np.where(instance_label_map == instance_id)
         | 
| 381 | 
            +
                        pos_list = list(zip(ys, xs))
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                        if len(pos_list) < 3:
         | 
| 384 | 
            +
                            continue
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                        pos_list_sorted = sort_and_expand_with_direction_v2(
         | 
| 387 | 
            +
                            pos_list, f_direction, p_tcl_map
         | 
| 388 | 
            +
                        )
         | 
| 389 | 
            +
                        all_pos_yxs.append(pos_list_sorted)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                p_char_maps = p_char_maps.transpose([1, 2, 0])
         | 
| 392 | 
            +
                decoded_str, keep_yxs_list = ctc_decoder_for_image(
         | 
| 393 | 
            +
                    all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table
         | 
| 394 | 
            +
                )
         | 
| 395 | 
            +
                return keep_yxs_list, decoded_str
         | 
| 396 | 
            +
             | 
| 397 | 
            +
             | 
| 398 | 
            +
            def extract_main_direction(pos_list, f_direction):
         | 
| 399 | 
            +
                """
         | 
| 400 | 
            +
                f_direction: h x w x 2
         | 
| 401 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 402 | 
            +
                """
         | 
| 403 | 
            +
                pos_list = np.array(pos_list)
         | 
| 404 | 
            +
                point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
         | 
| 405 | 
            +
                point_direction = point_direction[:, ::-1]  # x, y -> y, x
         | 
| 406 | 
            +
                average_direction = np.mean(point_direction, axis=0, keepdims=True)
         | 
| 407 | 
            +
                average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
         | 
| 408 | 
            +
                return average_direction
         | 
| 409 | 
            +
             | 
| 410 | 
            +
             | 
| 411 | 
            +
            def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
         | 
| 412 | 
            +
                """
         | 
| 413 | 
            +
                f_direction: h x w x 2
         | 
| 414 | 
            +
                pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
         | 
| 415 | 
            +
                """
         | 
| 416 | 
            +
                pos_list_full = np.array(pos_list).reshape(-1, 3)
         | 
| 417 | 
            +
                pos_list = pos_list_full[:, 1:]
         | 
| 418 | 
            +
                point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
         | 
| 419 | 
            +
                point_direction = point_direction[:, ::-1]  # x, y -> y, x
         | 
| 420 | 
            +
                average_direction = np.mean(point_direction, axis=0, keepdims=True)
         | 
| 421 | 
            +
                pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
         | 
| 422 | 
            +
                sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
         | 
| 423 | 
            +
                return sorted_list
         | 
| 424 | 
            +
             | 
| 425 | 
            +
             | 
| 426 | 
            +
            def sort_by_direction_with_image_id(pos_list, f_direction):
         | 
| 427 | 
            +
                """
         | 
| 428 | 
            +
                f_direction: h x w x 2
         | 
| 429 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 430 | 
            +
                """
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                def sort_part_with_direction(pos_list_full, point_direction):
         | 
| 433 | 
            +
                    pos_list_full = np.array(pos_list_full).reshape(-1, 3)
         | 
| 434 | 
            +
                    pos_list = pos_list_full[:, 1:]
         | 
| 435 | 
            +
                    point_direction = np.array(point_direction).reshape(-1, 2)
         | 
| 436 | 
            +
                    average_direction = np.mean(point_direction, axis=0, keepdims=True)
         | 
| 437 | 
            +
                    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
         | 
| 438 | 
            +
                    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
         | 
| 439 | 
            +
                    sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
         | 
| 440 | 
            +
                    return sorted_list, sorted_direction
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                pos_list = np.array(pos_list).reshape(-1, 3)
         | 
| 443 | 
            +
                point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
         | 
| 444 | 
            +
                point_direction = point_direction[:, ::-1]  # x, y -> y, x
         | 
| 445 | 
            +
                sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                point_num = len(sorted_point)
         | 
| 448 | 
            +
                if point_num >= 16:
         | 
| 449 | 
            +
                    middle_num = point_num // 2
         | 
| 450 | 
            +
                    first_part_point = sorted_point[:middle_num]
         | 
| 451 | 
            +
                    first_point_direction = sorted_direction[:middle_num]
         | 
| 452 | 
            +
                    sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
         | 
| 453 | 
            +
                        first_part_point, first_point_direction
         | 
| 454 | 
            +
                    )
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                    last_part_point = sorted_point[middle_num:]
         | 
| 457 | 
            +
                    last_point_direction = sorted_direction[middle_num:]
         | 
| 458 | 
            +
                    sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
         | 
| 459 | 
            +
                        last_part_point, last_point_direction
         | 
| 460 | 
            +
                    )
         | 
| 461 | 
            +
                    sorted_point = sorted_fist_part_point + sorted_last_part_point
         | 
| 462 | 
            +
                    sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                return sorted_point
         | 
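
As a quick sanity check of the greedy CTC path implemented above (an illustrative sketch only, not part of this diff): repeated argmax labels are collapsed per run and the blank index is then dropped.

# Illustrative sketch, not part of the PR; exercises ctc_greedy_decoder defined above.
import numpy as np

# Toy 6-step distribution over 3 classes, with index 2 playing the CTC blank.
probs = np.array(
    [
        [0.1, 0.1, 0.8],  # blank
        [0.7, 0.2, 0.1],  # class 0
        [0.8, 0.1, 0.1],  # class 0 again, merged with the previous step
        [0.1, 0.1, 0.8],  # blank
        [0.2, 0.7, 0.1],  # class 1
        [0.1, 0.1, 0.8],  # blank
    ]
)
dst_str, keep_idx_list = ctc_greedy_decoder(probs, blank=2, keep_blank_in_idxs=True)
assert dst_str == [0, 1]  # duplicates merged, blank removed
assert keep_idx_list == [0, 2, 3, 4, 5]  # one representative index per run (blank runs kept)

With keep_blank_in_idxs=True the positional indices retain one entry per run, including blank runs, which is what lets the caller map decoded characters back onto the text center line.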
    	
        ocr/postprocess/extract_textpoint_slow.py
    ADDED
    
@@ -0,0 +1,608 @@
| 1 | 
            +
            from __future__ import absolute_import, division, print_function
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from itertools import groupby
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import cv2
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            from skimage.morphology._skeletonize import thin
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def get_dict(character_dict_path):
         | 
| 12 | 
            +
                character_str = ""
         | 
| 13 | 
            +
                with open(character_dict_path, "rb") as fin:
         | 
| 14 | 
            +
                    lines = fin.readlines()
         | 
| 15 | 
            +
                    for line in lines:
         | 
| 16 | 
            +
                        line = line.decode("utf-8").strip("\n").strip("\r\n")
         | 
| 17 | 
            +
                        character_str += line
         | 
| 18 | 
            +
                    dict_character = list(character_str)
         | 
| 19 | 
            +
                return dict_character
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def point_pair2poly(point_pair_list):
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                Transfer vertical point_pairs into poly point in clockwise.
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
                pair_length_list = []
         | 
| 27 | 
            +
                for point_pair in point_pair_list:
         | 
| 28 | 
            +
                    pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
         | 
| 29 | 
            +
                    pair_length_list.append(pair_length)
         | 
| 30 | 
            +
                pair_length_list = np.array(pair_length_list)
         | 
| 31 | 
            +
                pair_info = (
         | 
| 32 | 
            +
                    pair_length_list.max(),
         | 
| 33 | 
            +
                    pair_length_list.min(),
         | 
| 34 | 
            +
                    pair_length_list.mean(),
         | 
| 35 | 
            +
                )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                point_num = len(point_pair_list) * 2
         | 
| 38 | 
            +
                point_list = [0] * point_num
         | 
| 39 | 
            +
                for idx, point_pair in enumerate(point_pair_list):
         | 
| 40 | 
            +
                    point_list[idx] = point_pair[0]
         | 
| 41 | 
            +
                    point_list[point_num - 1 - idx] = point_pair[1]
         | 
| 42 | 
            +
                return np.array(point_list).reshape(-1, 2), pair_info
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
         | 
| 46 | 
            +
                """
         | 
| 47 | 
            +
                Generate shrink_quad_along_width.
         | 
| 48 | 
            +
                """
         | 
| 49 | 
            +
                ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
         | 
| 50 | 
            +
                p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
         | 
| 51 | 
            +
                p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
         | 
| 52 | 
            +
                return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
         | 
| 56 | 
            +
                """
         | 
| 57 | 
            +
                expand poly along width.
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
                point_num = poly.shape[0]
         | 
| 60 | 
            +
                left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
         | 
| 61 | 
            +
                left_ratio = (
         | 
| 62 | 
            +
                    -shrink_ratio_of_width
         | 
| 63 | 
            +
                    * np.linalg.norm(left_quad[0] - left_quad[3])
         | 
| 64 | 
            +
                    / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
         | 
| 65 | 
            +
                )
         | 
| 66 | 
            +
                left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
         | 
| 67 | 
            +
                right_quad = np.array(
         | 
| 68 | 
            +
                    [
         | 
| 69 | 
            +
                        poly[point_num // 2 - 2],
         | 
| 70 | 
            +
                        poly[point_num // 2 - 1],
         | 
| 71 | 
            +
                        poly[point_num // 2],
         | 
| 72 | 
            +
                        poly[point_num // 2 + 1],
         | 
| 73 | 
            +
                    ],
         | 
| 74 | 
            +
                    dtype=np.float32,
         | 
| 75 | 
            +
                )
         | 
| 76 | 
            +
                right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
         | 
| 77 | 
            +
                    right_quad[0] - right_quad[3]
         | 
| 78 | 
            +
                ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
         | 
| 79 | 
            +
                right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
         | 
| 80 | 
            +
                poly[0] = left_quad_expand[0]
         | 
| 81 | 
            +
                poly[-1] = left_quad_expand[-1]
         | 
| 82 | 
            +
                poly[point_num // 2 - 1] = right_quad_expand[1]
         | 
| 83 | 
            +
                poly[point_num // 2] = right_quad_expand[2]
         | 
| 84 | 
            +
                return poly
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def softmax(logits):
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                logits: N x d
         | 
| 90 | 
            +
                """
         | 
| 91 | 
            +
                max_value = np.max(logits, axis=1, keepdims=True)
         | 
| 92 | 
            +
                exp = np.exp(logits - max_value)
         | 
| 93 | 
            +
                exp_sum = np.sum(exp, axis=1, keepdims=True)
         | 
| 94 | 
            +
                dist = exp / exp_sum
         | 
| 95 | 
            +
                return dist
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def get_keep_pos_idxs(labels, remove_blank=None):
         | 
| 99 | 
            +
                """
         | 
| 100 | 
            +
                Remove duplicate and get pos idxs of keep items.
         | 
| 101 | 
            +
                The value of keep_blank should be [None, 95].
         | 
| 102 | 
            +
                """
         | 
| 103 | 
            +
                duplicate_len_list = []
         | 
| 104 | 
            +
                keep_pos_idx_list = []
         | 
| 105 | 
            +
                keep_char_idx_list = []
         | 
| 106 | 
            +
                for k, v_ in groupby(labels):
         | 
| 107 | 
            +
                    current_len = len(list(v_))
         | 
| 108 | 
            +
                    if k != remove_blank:
         | 
| 109 | 
            +
                        current_idx = int(sum(duplicate_len_list) + current_len // 2)
         | 
| 110 | 
            +
                        keep_pos_idx_list.append(current_idx)
         | 
| 111 | 
            +
                        keep_char_idx_list.append(k)
         | 
| 112 | 
            +
                    duplicate_len_list.append(current_len)
         | 
| 113 | 
            +
                return keep_char_idx_list, keep_pos_idx_list
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            def remove_blank(labels, blank=0):
         | 
| 117 | 
            +
                new_labels = [x for x in labels if x != blank]
         | 
| 118 | 
            +
                return new_labels
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            def insert_blank(labels, blank=0):
         | 
| 122 | 
            +
                new_labels = [blank]
         | 
| 123 | 
            +
                for l in labels:
         | 
| 124 | 
            +
                    new_labels += [l, blank]
         | 
| 125 | 
            +
                return new_labels
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
         | 
| 129 | 
            +
                """
         | 
| 130 | 
            +
                CTC greedy (best path) decoder.
         | 
| 131 | 
            +
                """
         | 
| 132 | 
            +
                raw_str = np.argmax(np.array(probs_seq), axis=1)
         | 
| 133 | 
            +
                remove_blank_in_pos = None if keep_blank_in_idxs else blank
         | 
| 134 | 
            +
                dedup_str, keep_idx_list = get_keep_pos_idxs(
         | 
| 135 | 
            +
                    raw_str, remove_blank=remove_blank_in_pos
         | 
| 136 | 
            +
                )
         | 
| 137 | 
            +
                dst_str = remove_blank(dedup_str, blank=blank)
         | 
| 138 | 
            +
                return dst_str, keep_idx_list
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True):
         | 
| 142 | 
            +
                """
         | 
| 143 | 
            +
                gather_info: [[x, y], [x, y] ...]
         | 
| 144 | 
            +
                logits_map: H x W X (n_chars + 1)
         | 
| 145 | 
            +
                """
         | 
| 146 | 
            +
                _, _, C = logits_map.shape
         | 
| 147 | 
            +
                ys, xs = zip(*gather_info)
         | 
| 148 | 
            +
                logits_seq = logits_map[list(ys), list(xs)]  # n x 96
         | 
| 149 | 
            +
                probs_seq = softmax(logits_seq)
         | 
| 150 | 
            +
                dst_str, keep_idx_list = ctc_greedy_decoder(
         | 
| 151 | 
            +
                    probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs
         | 
| 152 | 
            +
                )
         | 
| 153 | 
            +
                keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
         | 
| 154 | 
            +
                return dst_str, keep_gather_list
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True):
         | 
| 158 | 
            +
                """
         | 
| 159 | 
            +
                CTC decoder using multiple processes.
         | 
| 160 | 
            +
                """
         | 
| 161 | 
            +
                decoder_results = []
         | 
| 162 | 
            +
                for gather_info in gather_info_list:
         | 
| 163 | 
            +
                    res = instance_ctc_greedy_decoder(
         | 
| 164 | 
            +
                        gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs
         | 
| 165 | 
            +
                    )
         | 
| 166 | 
            +
                    decoder_results.append(res)
         | 
| 167 | 
            +
                return decoder_results
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
            def sort_with_direction(pos_list, f_direction):
         | 
| 171 | 
            +
                """
         | 
| 172 | 
            +
                f_direction: h x w x 2
         | 
| 173 | 
            +
                pos_list: [[y, x], [y, x], [y, x] ...]
         | 
| 174 | 
            +
                """
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def sort_part_with_direction(pos_list, point_direction):
         | 
| 177 | 
            +
                    pos_list = np.array(pos_list).reshape(-1, 2)
         | 
| 178 | 
            +
                    point_direction = np.array(point_direction).reshape(-1, 2)
         | 
| 179 | 
            +
                    average_direction = np.mean(point_direction, axis=0, keepdims=True)
         | 
| 180 | 
            +
                    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
         | 
| 181 | 
            +
                    sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
         | 
| 182 | 
            +
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 2)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction
        )

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction
        )
        sorted_point = sorted_fist_part_point + sorted_last_part_point
        sorted_direction = sorted_fist_part_direction + sorted_last_part_direction

    return sorted_point, np.array(sorted_direction)


def add_id(pos_list, image_id=0):
    """
    Add id for gather feature, for inference.
    """
    new_list = []
    for item in pos_list:
        new_list.append((image_id, item[0], item[1]))
    return new_list


def sort_and_expand_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_dirction = point_direction[point_num - sub_direction_len :, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    left_list = []
    right_list = []
    for i in range(append_num):
        ly, lx = (
            np.round(left_start + left_step * (i + 1))
            .flatten()
            .astype("int32")
            .tolist()
        )
        if ly < h and lx < w and (ly, lx) not in left_list:
            left_list.append((ly, lx))
        ry, rx = (
            np.round(right_start + right_step * (i + 1))
            .flatten()
            .astype("int32")
            .tolist()
        )
        if ry < h and rx < w and (ry, rx) not in right_list:
            right_list.append((ry, rx))

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list
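
# Illustrative usage sketch (hypothetical values; assumes `np` is the numpy
# import at the top of this module): with a synthetic direction field that
# points to the right everywhere, the expansion above sorts the points by their
# projection on the mean direction and extrapolates extra points at both ends.
#
#   h, w = 32, 64
#   f_direction = np.zeros((h, w, 2), dtype="float32")
#   f_direction[..., 0] = 1.0                    # channel 0 is x: point right
#   pos_list = [(10, 22), (10, 20), (10, 21)]    # unordered points on one row
#   expanded = sort_and_expand_with_direction(pos_list, f_direction)
#   # -> points ordered left to right, with one extrapolated point added at
#   #    each end: (10, 19) on the left and (10, 23) on the right.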


def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    binary_tcl_map: h x w
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_dirction = point_direction[point_num - sub_direction_len :, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    max_append_num = 2 * append_num

    left_list = []
    right_list = []
    for i in range(max_append_num):
        ly, lx = (
            np.round(left_start + left_step * (i + 1))
            .flatten()
            .astype("int32")
            .tolist()
        )
        if ly < h and lx < w and (ly, lx) not in left_list:
            if binary_tcl_map[ly, lx] > 0.5:
                left_list.append((ly, lx))
            else:
                break

    for i in range(max_append_num):
        ry, rx = (
            np.round(right_start + right_step * (i + 1))
            .flatten()
            .astype("int32")
            .tolist()
        )
        if ry < h and rx < w and (ry, rx) not in right_list:
            if binary_tcl_map[ry, rx] > 0.5:
                right_list.append((ry, rx))
            else:
                break

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list


def generate_pivot_list_curved(
    p_score,
    p_char_maps,
    f_direction,
    score_thresh=0.5,
    is_expand=True,
    is_backbone=False,
    image_id=0,
):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8
    )

    # get TCL Instance
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []
    pred_strs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))

            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue

            if is_expand:
                pos_list_sorted = sort_and_expand_with_direction_v2(
                    pos_list, f_direction, p_tcl_map
                )
            else:
                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use decoder to filter background points.
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
    )
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
            pred_strs.append(decoded_str)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return pred_strs, instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs


def generate_pivot_list_horizontal(
    p_score, p_char_maps, f_direction, score_thresh=0.5, is_backbone=False, image_id=0
):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map_bi = (p_score > score_thresh) * 1.0
    instance_count, instance_label_map = cv2.connectedComponents(
        p_tcl_map_bi.astype(np.uint8), connectivity=8
    )

    # get TCL Instance
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []

    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))

            ### FIX-ME, eliminate outlier
            if len(pos_list) < 5:
                continue

            # add rule here
            main_direction = extract_main_direction(pos_list, f_direction)  # y x
            reference_directin = np.array([0, 1]).reshape([-1, 2])  # y x
            is_h_angle = abs(np.sum(main_direction * reference_directin)) < math.cos(
                math.pi / 180 * 70
            )

            point_yxs = np.array(pos_list)
            max_y, max_x = np.max(point_yxs, axis=0)
            min_y, min_x = np.min(point_yxs, axis=0)
            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)

            pos_list_final = []
            if is_h_len:
                xs = np.unique(xs)
                for x in xs:
                    ys = instance_label_map[:, x].copy().reshape((-1,))
                    y = int(np.where(ys == instance_id)[0].mean())
                    pos_list_final.append((y, x))
            else:
                ys = np.unique(ys)
                for y in ys:
                    xs = instance_label_map[y, :].copy().reshape((-1,))
                    x = int(np.where(xs == instance_id)[0].mean())
                    pos_list_final.append((y, x))

            pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use decoder to filter background points.
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
    )
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs


def generate_pivot_list_slow(
    p_score,
    p_char_maps,
    f_direction,
    score_thresh=0.5,
    is_backbone=False,
    is_curved=True,
    image_id=0,
):
    """
    Wrap all the functions together.
    """
    if is_curved:
        return generate_pivot_list_curved(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_expand=True,
            is_backbone=is_backbone,
            image_id=image_id,
        )
    else:
        return generate_pivot_list_horizontal(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_backbone=is_backbone,
            image_id=image_id,
        )


# for refine module
def extract_main_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    pos_list = np.array(pos_list)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
    return average_direction


def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """
    pos_list_full = np.array(pos_list).reshape(-1, 3)
    pos_list = pos_list_full[:, 1:]
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
    return sorted_list


def sort_by_direction_with_image_id(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """

    def sort_part_with_direction(pos_list_full, point_direction):
        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
        pos_list = pos_list_full[:, 1:]
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 3)
    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction
        )

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction
        )
        sorted_point = sorted_fist_part_point + sorted_last_part_point
        sorted_direction = sorted_fist_part_direction + sorted_last_part_direction

    return sorted_point


def generate_pivot_list_tt_inference(
    p_score,
    p_char_maps,
    f_direction,
    score_thresh=0.5,
    is_backbone=False,
    is_curved=True,
    image_id=0,
):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8
    )

    # get TCL Instance
    all_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue
            pos_list_sorted = sort_and_expand_with_direction_v2(
                pos_list, f_direction, p_tcl_map
            )
            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
            all_pos_yxs.append(pos_list_sorted_with_id)
    return all_pos_yxs
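
The pivot-list helpers above follow one pattern: threshold the TCL score map, thin it to a one-pixel skeleton, split the skeleton into instances with connected components, then sort and expand each instance along the direction field before the CTC decoder filters background points. A minimal, self-contained sketch of the first steps on a synthetic map (illustration only; the shapes and values are made up):

import cv2
import numpy as np
from skimage.morphology import thin

# Synthetic TCL probability map with two separate horizontal text-center regions.
p_tcl = np.zeros((40, 80), dtype="float32")
p_tcl[8:12, 5:40] = 0.9
p_tcl[25:29, 10:70] = 0.9

binary = (p_tcl > 0.5) * 1.0                    # same thresholding as above
skeleton = thin(binary > 0)                     # one-pixel-wide center lines
count, labels = cv2.connectedComponents(skeleton.astype(np.uint8), connectivity=8)

for instance_id in range(1, count):             # label 0 is the background
    ys, xs = np.where(labels == instance_id)
    print(instance_id, len(ys), "skeleton pixels")   # one TCL instance per region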
    	
        ocr/postprocess/fce_postprocess.py
    ADDED
    
@@ -0,0 +1,234 @@

import cv2
import numpy as np
import paddle
from numpy.fft import ifft

from .poly_nms import *


def fill_hole(input_mask):
    h, w = input_mask.shape
    canvas = np.zeros((h + 2, w + 2), np.uint8)
    canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()

    mask = np.zeros((h + 4, w + 4), np.uint8)

    cv2.floodFill(canvas, mask, (0, 0), 1)
    canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool)

    return ~canvas | input_mask


def fourier2poly(fourier_coeff, num_reconstr_points=50):
    """Inverse Fourier transform
    Args:
        fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
            with n and k being candidates number and Fourier degree
            respectively.
        num_reconstr_points (int): Number of reconstructed polygon points.
    Returns:
        Polygons (ndarray): The reconstructed polygons shaped (n, n')
    """

    a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex")
    k = (len(fourier_coeff[0]) - 1) // 2

    a[:, 0 : k + 1] = fourier_coeff[:, k:]
    a[:, -k:] = fourier_coeff[:, :k]

    poly_complex = ifft(a) * num_reconstr_points
    polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
    polygon[:, :, 0] = poly_complex.real
    polygon[:, :, 1] = poly_complex.imag
    return polygon.astype("int32").reshape((len(fourier_coeff), -1))


class FCEPostProcess(object):
    """
    The post process for FCENet.
    """

    def __init__(
        self,
        scales,
        fourier_degree=5,
        num_reconstr_points=50,
        decoding_type="fcenet",
        score_thr=0.3,
        nms_thr=0.1,
        alpha=1.0,
        beta=1.0,
        box_type="poly",
        **kwargs
    ):

        self.scales = scales
        self.fourier_degree = fourier_degree
        self.num_reconstr_points = num_reconstr_points
        self.decoding_type = decoding_type
        self.score_thr = score_thr
        self.nms_thr = nms_thr
        self.alpha = alpha
        self.beta = beta
        self.box_type = box_type

    def __call__(self, preds, shape_list):
        score_maps = []
        for key, value in preds.items():
            if isinstance(value, paddle.Tensor):
                value = value.numpy()
            cls_res = value[:, :4, :, :]
            reg_res = value[:, 4:, :, :]
            score_maps.append([cls_res, reg_res])

        return self.get_boundary(score_maps, shape_list)

    def resize_boundary(self, boundaries, scale_factor):
        """Rescale boundaries via scale_factor.

        Args:
            boundaries (list[list[float]]): The boundary list. Each boundary
            with size 2k+1 with k>=4.
            scale_factor(ndarray): The scale factor of size (4,).

        Returns:
            boundaries (list[list[float]]): The scaled boundaries.
        """
        boxes = []
        scores = []
        for b in boundaries:
            sz = len(b)
            valid_boundary(b, True)
            scores.append(b[-1])
            b = (
                (
                    np.array(b[: sz - 1])
                    * (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1))
                )
                .flatten()
                .tolist()
            )
            boxes.append(np.array(b).reshape([-1, 2]))

        return np.array(boxes, dtype=np.float32), scores

    def get_boundary(self, score_maps, shape_list):
        assert len(score_maps) == len(self.scales)
        boundaries = []
        for idx, score_map in enumerate(score_maps):
            scale = self.scales[idx]
            boundaries = boundaries + self._get_boundary_single(score_map, scale)

        # nms
        boundaries = poly_nms(boundaries, self.nms_thr)
        boundaries, scores = self.resize_boundary(
            boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
        )

        boxes_batch = [dict(points=boundaries, scores=scores)]
        return boxes_batch

    def _get_boundary_single(self, score_map, scale):
        assert len(score_map) == 2
        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2

        return self.fcenet_decode(
            preds=score_map,
            fourier_degree=self.fourier_degree,
            num_reconstr_points=self.num_reconstr_points,
            scale=scale,
            alpha=self.alpha,
            beta=self.beta,
            box_type=self.box_type,
            score_thr=self.score_thr,
            nms_thr=self.nms_thr,
        )

    def fcenet_decode(
        self,
        preds,
        fourier_degree,
        num_reconstr_points,
        scale,
        alpha=1.0,
        beta=2.0,
        box_type="poly",
        score_thr=0.3,
        nms_thr=0.1,
    ):
        """Decoding predictions of FCENet to instances.

        Args:
            preds (list(Tensor)): The head output tensors.
            fourier_degree (int): The maximum Fourier transform degree k.
            num_reconstr_points (int): The points number of the polygon
                reconstructed from predicted Fourier coefficients.
            scale (int): The down-sample scale of the prediction.
            alpha (float) : The parameter to calculate final scores. Score_{final}
                    = (Score_{text region} ^ alpha)
                    * (Score_{text center region}^ beta)
            beta (float) : The parameter to calculate final score.
            box_type (str):  Boundary encoding type 'poly' or 'quad'.
            score_thr (float) : The threshold used to filter out the final
                candidates.
            nms_thr (float) :  The threshold of nms.

        Returns:
            boundaries (list[list[float]]): The instance boundary and confidence
                list.
        """
        assert isinstance(preds, list)
        assert len(preds) == 2
        assert box_type in ["poly", "quad"]

        cls_pred = preds[0][0]
        tr_pred = cls_pred[0:2]
        tcl_pred = cls_pred[2:]

        reg_pred = preds[1][0].transpose([1, 2, 0])
        x_pred = reg_pred[:, :, : 2 * fourier_degree + 1]
        y_pred = reg_pred[:, :, 2 * fourier_degree + 1 :]

        score_pred = (tr_pred[1] ** alpha) * (tcl_pred[1] ** beta)
        tr_pred_mask = (score_pred) > score_thr
        tr_mask = fill_hole(tr_pred_mask)

        tr_contours, _ = cv2.findContours(
            tr_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )  # opencv4

        mask = np.zeros_like(tr_mask)
        boundaries = []
        for cont in tr_contours:
            deal_map = mask.copy().astype(np.int8)
            cv2.drawContours(deal_map, [cont], -1, 1, -1)

            score_map = score_pred * deal_map
            score_mask = score_map > 0
            xy_text = np.argwhere(score_mask)
            dxy = xy_text[:, 1] + xy_text[:, 0] * 1j

            x, y = x_pred[score_mask], y_pred[score_mask]
            c = x + y * 1j
            c[:, fourier_degree] = c[:, fourier_degree] + dxy
            c *= scale

            polygons = fourier2poly(c, num_reconstr_points)
            score = score_map[score_mask].reshape(-1, 1)
            polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)

            boundaries = boundaries + polygons

        boundaries = poly_nms(boundaries, nms_thr)

        if box_type == "quad":
            new_boundaries = []
            for boundary in boundaries:
                poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
                score = boundary[-1]
                points = cv2.boxPoints(cv2.minAreaRect(poly))
                points = np.int0(points)
                new_boundaries.append(points.reshape(-1).tolist() + [score])
                boundaries = new_boundaries

        return boundaries
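
fourier2poly above is the step that turns each candidate's Fourier coefficients back into polygon points with an inverse FFT. A self-contained sketch of that reconstruction (illustration only, redoing the same arithmetic with plain numpy rather than importing from this module): a contour whose only non-zero coefficients are the centre term c_0 and the first harmonic c_1 comes back as a 50-point circle.

import numpy as np
from numpy.fft import ifft

k = 5                                      # matches the default fourier_degree
coeff = np.zeros((1, 2 * k + 1), dtype="complex")
coeff[0, k] = 40 + 30j                     # c_0: centre at (x=40, y=30)
coeff[0, k + 1] = 12                       # c_1: circle of radius 12

n_points = 50
a = np.zeros((1, n_points), dtype="complex")
a[:, 0 : k + 1] = coeff[:, k:]             # positive frequencies c_0 .. c_k
a[:, -k:] = coeff[:, :k]                   # negative frequencies c_-k .. c_-1
poly = ifft(a) * n_points                  # same reconstruction as fourier2poly
xs, ys = poly.real[0], poly.imag[0]
print(xs.min(), xs.max(), ys.min(), ys.max())   # roughly 28, 52, 18, 42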
    	
        ocr/postprocess/locality_aware_nms.py
    ADDED
    
@@ -0,0 +1,198 @@

"""
Locality aware nms.
This code is referred from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
"""

import numpy as np
from shapely.geometry import Polygon


def intersection(g, p):
    """
    Intersection.
    """
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    inter = Polygon(g).intersection(Polygon(p)).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter / union


def intersection_iog(g, p):
    """
    Intersection_iog.
    """
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    if not g.is_valid or not p.is_valid:
        return 0
    inter = Polygon(g).intersection(Polygon(p)).area
    # union = g.area + p.area - inter
    union = p.area
    if union == 0:
        print("p_area is very small")
        return 0
    else:
        return inter / union


def weighted_merge(g, p):
    """
    Weighted merge.
    """
    g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
    g[8] = g[8] + p[8]
    return g


def standard_nms(S, thres):
    """
    Standard nms.
    """
    order = np.argsort(S[:, 8])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])

        inds = np.where(ovr <= thres)[0]
        order = order[inds + 1]

    return S[keep]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def standard_nms_inds(S, thres):
         | 
| 73 | 
            +
                """
         | 
| 74 | 
            +
                Standard nms, retun inds.
         | 
| 75 | 
            +
                """
         | 
| 76 | 
            +
                order = np.argsort(S[:, 8])[::-1]
         | 
| 77 | 
            +
                keep = []
         | 
| 78 | 
            +
                while order.size > 0:
         | 
| 79 | 
            +
                    i = order[0]
         | 
| 80 | 
            +
                    keep.append(i)
         | 
| 81 | 
            +
                    ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    inds = np.where(ovr <= thres)[0]
         | 
| 84 | 
            +
                    order = order[inds + 1]
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                return keep
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            def nms(S, thres):
         | 
| 90 | 
            +
                """
         | 
| 91 | 
            +
                nms.
         | 
| 92 | 
            +
                """
         | 
| 93 | 
            +
                order = np.argsort(S[:, 8])[::-1]
         | 
| 94 | 
            +
                keep = []
         | 
| 95 | 
            +
                while order.size > 0:
         | 
| 96 | 
            +
                    i = order[0]
         | 
| 97 | 
            +
                    keep.append(i)
         | 
| 98 | 
            +
                    ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    inds = np.where(ovr <= thres)[0]
         | 
| 101 | 
            +
                    order = order[inds + 1]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                return keep
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
         | 
| 107 | 
            +
                """
         | 
| 108 | 
            +
                soft_nms
         | 
| 109 | 
            +
                :para boxes_in, N x 9 (coords + score)
         | 
| 110 | 
            +
                :para threshould, eliminate cases min score(0.001)
         | 
| 111 | 
            +
                :para Nt_thres, iou_threshi
         | 
| 112 | 
            +
                :para sigma, gaussian weght
         | 
| 113 | 
            +
                :method, linear or gaussian
         | 
| 114 | 
            +
                """
         | 
| 115 | 
            +
                boxes = boxes_in.copy()
         | 
| 116 | 
            +
                N = boxes.shape[0]
         | 
| 117 | 
            +
                if N is None or N < 1:
         | 
| 118 | 
            +
                    return np.array([])
         | 
| 119 | 
            +
                pos, maxpos = 0, 0
         | 
| 120 | 
            +
                weight = 0.0
         | 
| 121 | 
            +
                inds = np.arange(N)
         | 
| 122 | 
            +
                tbox, sbox = boxes[0].copy(), boxes[0].copy()
         | 
| 123 | 
            +
                for i in range(N):
         | 
| 124 | 
            +
                    maxscore = boxes[i, 8]
         | 
| 125 | 
            +
                    maxpos = i
         | 
| 126 | 
            +
                    tbox = boxes[i].copy()
         | 
| 127 | 
            +
                    ti = inds[i]
         | 
| 128 | 
            +
                    pos = i + 1
         | 
| 129 | 
            +
                    # get max box
         | 
| 130 | 
            +
                    while pos < N:
         | 
| 131 | 
            +
                        if maxscore < boxes[pos, 8]:
         | 
| 132 | 
            +
                            maxscore = boxes[pos, 8]
         | 
| 133 | 
            +
                            maxpos = pos
         | 
| 134 | 
            +
                        pos = pos + 1
         | 
| 135 | 
            +
                    # add max box as a detection
         | 
| 136 | 
            +
                    boxes[i, :] = boxes[maxpos, :]
         | 
| 137 | 
            +
                    inds[i] = inds[maxpos]
         | 
| 138 | 
            +
                    # swap
         | 
| 139 | 
            +
                    boxes[maxpos, :] = tbox
         | 
| 140 | 
            +
                    inds[maxpos] = ti
         | 
| 141 | 
            +
                    tbox = boxes[i].copy()
         | 
| 142 | 
            +
                    pos = i + 1
         | 
| 143 | 
            +
                    # NMS iteration
         | 
| 144 | 
            +
                    while pos < N:
         | 
| 145 | 
            +
                        sbox = boxes[pos].copy()
         | 
| 146 | 
            +
                        ts_iou_val = intersection(tbox, sbox)
         | 
| 147 | 
            +
                        if ts_iou_val > 0:
         | 
| 148 | 
            +
                            if method == 1:
         | 
| 149 | 
            +
                                if ts_iou_val > Nt_thres:
         | 
| 150 | 
            +
                                    weight = 1 - ts_iou_val
         | 
| 151 | 
            +
                                else:
         | 
| 152 | 
            +
                                    weight = 1
         | 
| 153 | 
            +
                            elif method == 2:
         | 
| 154 | 
            +
                                weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
         | 
| 155 | 
            +
                            else:
         | 
| 156 | 
            +
                                if ts_iou_val > Nt_thres:
         | 
| 157 | 
            +
                                    weight = 0
         | 
| 158 | 
            +
                                else:
         | 
| 159 | 
            +
                                    weight = 1
         | 
| 160 | 
            +
                            boxes[pos, 8] = weight * boxes[pos, 8]
         | 
| 161 | 
            +
                            # if box score falls below thresold, discard the box by
         | 
| 162 | 
            +
                            # swaping last box update N
         | 
| 163 | 
            +
                            if boxes[pos, 8] < threshold:
         | 
| 164 | 
            +
                                boxes[pos, :] = boxes[N - 1, :]
         | 
| 165 | 
            +
                                inds[pos] = inds[N - 1]
         | 
| 166 | 
            +
                                N = N - 1
         | 
| 167 | 
            +
                                pos = pos - 1
         | 
| 168 | 
            +
                        pos = pos + 1
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                return boxes[:N]
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
            def nms_locality(polys, thres=0.3):
         | 
| 174 | 
            +
                """
         | 
| 175 | 
            +
                locality aware nms of EAST
         | 
| 176 | 
            +
                :param polys: a N*9 numpy array. first 8 coordinates, then prob
         | 
| 177 | 
            +
                :return: boxes after nms
         | 
| 178 | 
            +
                """
         | 
| 179 | 
            +
                S = []
         | 
| 180 | 
            +
                p = None
         | 
| 181 | 
            +
                for g in polys:
         | 
| 182 | 
            +
                    if p is not None and intersection(g, p) > thres:
         | 
| 183 | 
            +
                        p = weighted_merge(g, p)
         | 
| 184 | 
            +
                    else:
         | 
| 185 | 
            +
                        if p is not None:
         | 
| 186 | 
            +
                            S.append(p)
         | 
| 187 | 
            +
                        p = g
         | 
| 188 | 
            +
                if p is not None:
         | 
| 189 | 
            +
                    S.append(p)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                if len(S) == 0:
         | 
| 192 | 
            +
                    return np.array([])
         | 
| 193 | 
            +
                return standard_nms(np.array(S), thres)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            if __name__ == "__main__":
         | 
| 197 | 
            +
                # 343,350,448,135,474,143,369,359
         | 
| 198 | 
            +
                print(Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])).area)
         | 
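A minimal usage sketch (editorial note, not part of the diff): assuming the module is importable as ocr.postprocess.locality_aware_nms and that numpy and shapely are installed, nms_locality and soft_nms both consume an N x 9 array of quadrilateral corners plus a confidence score. The toy boxes below are invented.

import numpy as np

from ocr.postprocess.locality_aware_nms import nms_locality, soft_nms  # assumed import path

# Two heavily overlapping quads and one distant quad; each row is
# [x1, y1, x2, y2, x3, y3, x4, y4, score].
polys = np.array(
    [
        [0, 0, 100, 0, 100, 30, 0, 30, 0.9],
        [2, 1, 102, 1, 102, 31, 2, 31, 0.8],
        [200, 200, 300, 200, 300, 230, 200, 230, 0.7],
    ],
    dtype=np.float32,
)

merged = nms_locality(polys.copy(), thres=0.3)  # weighted merge of adjacent rows, then standard_nms
decayed = soft_nms(polys, Nt_thres=0.3)         # Gaussian score decay instead of hard suppression
print(merged.shape, decayed.shape)              # the two overlapping quads collapse into one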
    	
        ocr/postprocess/pg_postprocess.py
    ADDED
    
@@ -0,0 +1,189 @@
from __future__ import absolute_import, division, print_function

import os
import sys

import numpy as np  # used directly below; imported explicitly rather than via the wildcard imports
import paddle

from .extract_textpoint_fast import *
from .extract_textpoint_slow import *

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, ".."))


class PGNet_PostProcess(object):
    # supports two post-processing modes: fast and slow
    def __init__(
        self, character_dict_path, valid_set, score_thresh, outs_dict, shape_list
    ):
        self.Lexicon_Table = get_dict(character_dict_path)
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.outs_dict = outs_dict
        self.shape_list = shape_list

    def pg_postprocess_fast(self):
        p_score = self.outs_dict["f_score"]
        p_border = self.outs_dict["f_border"]
        p_char = self.outs_dict["f_char"]
        p_direction = self.outs_dict["f_direction"]
        if isinstance(p_score, paddle.Tensor):
            p_score = p_score[0].numpy()
            p_border = p_border[0].numpy()
            p_direction = p_direction[0].numpy()
            p_char = p_char[0].numpy()
        else:
            p_score = p_score[0]
            p_border = p_border[0]
            p_direction = p_direction[0]
            p_char = p_char[0]

        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        instance_yxs_list, seq_strs = generate_pivot_list_fast(
            p_score,
            p_char,
            p_direction,
            self.Lexicon_Table,
            score_thresh=self.score_thresh,
        )
        poly_list, keep_str_list = restore_poly(
            instance_yxs_list,
            seq_strs,
            p_border,
            ratio_w,
            ratio_h,
            src_w,
            src_h,
            self.valid_set,
        )
        data = {
            "points": poly_list,
            "texts": keep_str_list,
        }
        return data

    def pg_postprocess_slow(self):
        p_score = self.outs_dict["f_score"]
        p_border = self.outs_dict["f_border"]
        p_char = self.outs_dict["f_char"]
        p_direction = self.outs_dict["f_direction"]
        if isinstance(p_score, paddle.Tensor):
            p_score = p_score[0].numpy()
            p_border = p_border[0].numpy()
            p_direction = p_direction[0].numpy()
            p_char = p_char[0].numpy()
        else:
            p_score = p_score[0]
            p_border = p_border[0]
            p_direction = p_direction[0]
            p_char = p_char[0]
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        is_curved = self.valid_set == "totaltext"
        char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved,
        )
        seq_strs = []
        for char_idx_set in char_seq_idx_set:
            pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
            seq_strs.append(pr_str)
        poly_list = []
        keep_str_list = []
        all_point_list = []
        all_point_pair_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            if len(yx_center_line) == 1:
                yx_center_line.append(yx_center_line[-1])

            offset_expand = 1.0
            if self.valid_set == "totaltext":
                offset_expand = 1.2

            point_pair_list = []
            for batch_id, y, x in yx_center_line:
                offset = p_border[:, y, x].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
                    )
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (
                    (ori_yx + offset)[:, ::-1]
                    * 4.0
                    / np.array([ratio_w, ratio_h]).reshape(-1, 2)
                )
                point_pair_list.append(point_pair)

                all_point_list.append(
                    [int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))]
                )
                all_point_pair_list.append(point_pair.round().astype(np.int32).tolist())

            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2
            )
            detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)

            if len(keep_str) < 2:
                continue

            keep_str_list.append(keep_str)
            detected_poly = np.round(detected_poly).astype("int32")
            if self.valid_set == "partvgg":
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :
                ]
                poly_list.append(detected_poly)
            elif self.valid_set == "totaltext":
                poly_list.append(detected_poly)
            else:
                print("--> Not supported format.")
                exit(-1)
        data = {
            "points": poly_list,
            "texts": keep_str_list,
        }
        return data


class PGPostProcess(object):
    """
    The post process for PGNet.
    """

    def __init__(self, character_dict_path, valid_set, score_thresh, mode, **kwargs):
        self.character_dict_path = character_dict_path
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.mode = mode

        # c++ la-nms is faster, but only supports python 3.5
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def __call__(self, outs_dict, shape_list):
        post = PGNet_PostProcess(
            self.character_dict_path,
            self.valid_set,
            self.score_thresh,
            outs_dict,
            shape_list,
        )
        if self.mode == "fast":
            data = post.pg_postprocess_fast()
        else:
            data = post.pg_postprocess_slow()
        return data
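A hedged wiring sketch (not part of the diff): PGPostProcess consumes the raw PGNet heads (f_score, f_border, f_char, f_direction) plus the per-image shape list produced during preprocessing. The dict path, valid_set and thresholds below are placeholders; real values come from the predictor configuration.

from ocr.postprocess.pg_postprocess import PGPostProcess  # assumed import path


def run_pgnet_postprocess(outs_dict, shape_list, character_dict_path):
    """outs_dict and shape_list are assumed to come from a PGNet predictor run."""
    postprocess = PGPostProcess(
        character_dict_path=character_dict_path,  # path to an end-to-end character dict file
        valid_set="totaltext",
        score_thresh=0.5,
        mode="fast",  # "fast" -> pg_postprocess_fast, anything else -> the slow path
    )
    result = postprocess(outs_dict, shape_list)
    return result["points"], result["texts"]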
    	
        ocr/postprocess/poly_nms.py
    ADDED
    
@@ -0,0 +1,132 @@
import numpy as np
from shapely.geometry import Polygon


def points2polygon(points):
    """Convert k points to 1 polygon.

    Args:
        points (ndarray or list): A ndarray or a list of shape (2k)
            that indicates k points.

    Returns:
        polygon (Polygon): A polygon object.
    """
    if isinstance(points, list):
        points = np.array(points)

    assert isinstance(points, np.ndarray)
    assert (points.size % 2 == 0) and (points.size >= 8)

    point_mat = points.reshape([-1, 2])
    return Polygon(point_mat)


def poly_intersection(poly_det, poly_gt, buffer=0.0001):
    """Calculate the intersection area between two polygons.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        intersection_area (float): The intersection area between two polygons.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)

    if buffer == 0:
        poly_inter = poly_det & poly_gt
    else:
        poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
    return poly_inter.area, poly_inter


def poly_union(poly_det, poly_gt):
    """Calculate the union area between two polygons.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        union_area (float): The union area between two polygons.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)

    area_det = poly_det.area
    area_gt = poly_gt.area
    area_inters, _ = poly_intersection(poly_det, poly_gt)
    return area_det + area_gt - area_inters


def valid_boundary(x, with_score=True):
    num = len(x)
    if num < 8:
        return False
    if num % 2 == 0 and (not with_score):
        return True
    if num % 2 == 1 and with_score:
        return True

    return False


def boundary_iou(src, target):
    """Calculate the IOU between two boundaries.

    Args:
       src (list): Source boundary.
       target (list): Target boundary.

    Returns:
       iou (float): The iou between two boundaries.
    """
    assert valid_boundary(src, False)
    assert valid_boundary(target, False)
    src_poly = points2polygon(src)
    target_poly = points2polygon(target)

    return poly_iou(src_poly, target_poly)


def poly_iou(poly_det, poly_gt):
    """Calculate the IOU between two polygons.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        iou (float): The IOU between two polygons.
    """
    assert isinstance(poly_det, Polygon)
    assert isinstance(poly_gt, Polygon)
    area_inters, _ = poly_intersection(poly_det, poly_gt)
    area_union = poly_union(poly_det, poly_gt)
    if area_union == 0:
        return 0.0
    return area_inters / area_union


def poly_nms(polygons, threshold):
    assert isinstance(polygons, list)

    polygons = np.array(sorted(polygons, key=lambda x: x[-1]))

    keep_poly = []
    index = [i for i in range(polygons.shape[0])]

    while len(index) > 0:
        keep_poly.append(polygons[index[-1]].tolist())
        A = polygons[index[-1]][:-1]
        index = np.delete(index, -1)
        iou_list = np.zeros((len(index),))
        for i in range(len(index)):
            B = polygons[index[i]][:-1]
            iou_list[i] = boundary_iou(A, B)
        remove_index = np.where(iou_list > threshold)
        index = np.delete(index, remove_index)

    return keep_poly
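A minimal, self-contained sketch of poly_nms (not part of the diff; the toy boundaries are invented): each candidate is a flat list of 2k corner coordinates followed by a confidence score, and candidates whose boundary IoU with a higher-scoring candidate exceeds the threshold are dropped.

from ocr.postprocess.poly_nms import poly_nms  # assumed import path

# Each entry: [x1, y1, ..., xk, yk, score]; the second box nearly duplicates the first.
candidates = [
    [0, 0, 100, 0, 100, 30, 0, 30, 0.9],
    [1, 1, 101, 1, 101, 31, 1, 31, 0.8],
    [200, 0, 300, 0, 300, 30, 200, 30, 0.7],
]
kept = poly_nms(candidates, threshold=0.5)
print(len(kept))  # 2: the near-duplicate is suppressed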
    	
        ocr/postprocess/pse_postprocess/__init__.py
    ADDED
    
@@ -0,0 +1 @@
from .pse_postprocess import PSEPostProcess
    	
        ocr/postprocess/pse_postprocess/pse/__init__.py
    ADDED
    
@@ -0,0 +1,20 @@
import os
import subprocess
import sys

python_path = sys.executable

ori_path = os.getcwd()
# Build the Cython extension in-place next to this file. The package lives under
# ocr/postprocess/pse_postprocess/pse in this repo, so the hard-coded "ppocr/..."
# path inherited from PaddleOCR would not resolve from the app's working directory.
os.chdir(os.path.dirname(os.path.realpath(__file__)))
if (
    subprocess.call("{} setup.py build_ext --inplace".format(python_path), shell=True)
    != 0
):
    raise RuntimeError(
        "Cannot compile pse: {}. If your system is Windows, install the default "
        "components of `Desktop development with C++` in Visual Studio 2019+.".format(
            os.path.dirname(os.path.realpath(__file__))
        )
    )
os.chdir(ori_path)

from .pse import pse
    	
        ocr/postprocess/pse_postprocess/pse/pse.pyx
    ADDED
    
@@ -0,0 +1,72 @@
import cv2
import numpy as np

cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
cimport numpy as np
from libcpp.pair cimport *
from libcpp.queue cimport *


@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                         np.ndarray[np.int32_t, ndim=2] label,
                                         int kernel_num,
                                         int label_num,
                                         float min_area=0):
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef np.int16_t* dx = [-1, 1, 0, 0]
    cdef np.int16_t* dy = [0, 0, -1, 1]
    cdef np.int16_t tmpx, tmpy

    points = np.array(np.where(label > 0)).transpose((1, 0))
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
    cdef int cur_label
    for kernel_idx in range(kernel_num - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
                    continue
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            if is_edge:
                nxt_que.push(cur)

        que, nxt_que = nxt_que, que

    return pred

def pse(kernels, min_area):
    kernel_num = kernels.shape[0]
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
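For orientation (editorial sketch, not part of the diff): pse expects a stack of progressively shrunken binary kernel maps, labels connected components on the smallest kernel, and grows the labels outward kernel by kernel. The toy kernels below are invented, and the call only works once the extension has been compiled (see the accompanying setup.py and the package __init__).

import numpy as np

from ocr.postprocess.pse_postprocess.pse import pse  # assumed import path

# Three nested binary kernels of a single text region, stacked as (kernel_num, H, W) uint8.
h, w = 32, 32
kernels = np.zeros((3, h, w), dtype=np.uint8)
kernels[0, 4:28, 4:28] = 1  # largest kernel (full text region)
kernels[1, 6:26, 6:26] = 1
kernels[2, 8:24, 8:24] = 1  # smallest kernel, used to seed the instance labels

label_map = pse(kernels, min_area=5)
print(label_map.shape)  # (32, 32): one integer instance id per pixel, 0 for background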
    	
        ocr/postprocess/pse_postprocess/pse/setup.py
    ADDED
    
@@ -0,0 +1,19 @@
from distutils.core import Extension, setup

import numpy
from Cython.Build import cythonize

setup(
    ext_modules=cythonize(
        Extension(
            "pse",
            sources=["pse.pyx"],
            language="c++",
            include_dirs=[numpy.get_include()],
            library_dirs=[],
            libraries=[],
            extra_compile_args=["-O3"],
            extra_link_args=[],
        )
    )
)
    	
        ocr/postprocess/pse_postprocess/pse_postprocess.py
    ADDED
    
@@ -0,0 +1,100 @@
from __future__ import absolute_import, division, print_function

import cv2
import numpy as np
import paddle
from paddle.nn import functional as F

from .pse import pse


class PSEPostProcess(object):
    """
    The post process for PSE.
    """

    def __init__(
        self,
        thresh=0.5,
        box_thresh=0.85,
        min_area=16,
        box_type="quad",
        scale=4,
        **kwargs
    ):
        assert box_type in ["quad", "poly"], "Only quad and poly is supported"
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.min_area = min_area
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict["maps"]
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(pred, scale_factor=4 // self.scale, mode="bilinear")

        score = F.sigmoid(pred[:, 0, :, :])

        kernels = (pred > self.thresh).astype("float32")
        text_mask = kernels[:, 0, :, :]
        kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask

        score = score.numpy()
        kernels = kernels.numpy().astype(np.uint8)

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            boxes, scores = self.boxes_from_bitmap(
                score[batch_index], kernels[batch_index], shape_list[batch_index]
            )

            boxes_batch.append({"points": boxes, "scores": scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, shape):
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, shape)

    def generate_box(self, score, label, shape):
        src_h, src_w, ratio_h, ratio_w = shape
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < self.min_area:
                label[ind] = 0
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                label[ind] = 0
                continue

            if self.box_type == "quad":
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == "poly":
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                contours, _ = cv2.findContours(
                    mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
            bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores
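A hedged smoke-test sketch (not part of the diff): random values stand in for real PSE network output of shape (batch, kernel_num, H, W), and shape_list carries [src_h, src_w, ratio_h, ratio_w] per image; in practice both come from the detector and its preprocessing, and the pse Cython extension must already be compiled.

import numpy as np

from ocr.postprocess.pse_postprocess import PSEPostProcess  # assumed import path

# Fake network output: 7 kernel maps at 1/4 resolution of a 736x736 input (values invented).
fake_maps = np.random.rand(1, 7, 184, 184).astype("float32")
shape_list = np.array([[736.0, 736.0, 0.25, 0.25]])  # [src_h, src_w, ratio_h, ratio_w]

post = PSEPostProcess(thresh=0.5, box_thresh=0.85, min_area=16, box_type="quad", scale=4)
result = post({"maps": fake_maps}, shape_list)
print(len(result[0]["points"]), "candidate boxes")  # likely few or none for pure noise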
    	
        ocr/postprocess/rec_postprocess.py
    ADDED
    
@@ -0,0 +1,731 @@
import re

import numpy as np
import paddle
from paddle.nn import functional as F


class BaseRecLabelDecode(object):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = "sos"
        self.end_str = "eos"

        self.character_str = []
        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode("utf-8").strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def add_special_char(self, dict_character):
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            preds = preds[-1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank"] + dict_character
        return dict_character


class DistillationCTCLabelDecode(CTCLabelDecode):
    """
    Convert
    Convert between text-label and text-index
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationCTCLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        if not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred["ctc"]
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output


class NRTRLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):

        if len(preds) == 2:
            preds_id = preds[0]
            preds_prob = preds[1]
            if isinstance(preds_id, paddle.Tensor):
                preds_id = preds_id.numpy()
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
            if preds_id[0][0] == 2:
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        else:
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] == 3:  # end
                    break
                try:
                    char_list.append(self.character[int(text_index[batch_idx][idx])])
                except:
                    continue
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list


class AttnLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = dict_character
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        [beg_idx, end_idx] = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """
        text = self.decode(text)
        if label is None:
            return text
        else:
            label = self.decode(label, is_remove_duplicate=False)
            return text, label
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx


class SEEDLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.padding_str = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
        return dict_character

    def get_ignored_tokens(self):
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """
        text = self.decode(text)
        if label is None:
            return text
        else:
            label = self.decode(label, is_remove_duplicate=False)
            return text, label
        """
        preds_idx = preds["rec_pred"]
        if isinstance(preds_idx, paddle.Tensor):
            preds_idx = preds_idx.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = preds["rec_pred"]
            preds_prob = preds["rec_pred_scores"]
        else:
            preds_idx = preds["rec_pred"].argmax(axis=2)
            preds_prob = preds["rec_pred"].max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label


class SRNLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
        self.max_text_length = kwargs.get("max_text_length", 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        pred = preds["predict"]
        char_num = len(self.character_str) + 2
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])

        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        text = self.decode(preds_idx, preds_prob)

        if label is None:
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx


class TableLabelDecode(object):
    """ """

    def __init__(self, character_dict_path, **kwargs):
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        self.dict_idx_character = {}
        for i, char in enumerate(list_character):
            self.dict_idx_character[i] = char
            self.dict_character[char] = i
        self.dict_elem = {}
        self.dict_idx_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_idx_elem[i] = elem
            self.dict_elem[elem] = i

    def load_char_elem_dict(self, character_dict_path):
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            substr = lines[0].decode("utf-8").strip("\n").strip("\r\n").split("\t")
            character_num = int(substr[0])
            elem_num = int(substr[1])
            for cno in range(1, 1 + character_num):
                character = lines[cno].decode("utf-8").strip("\n").strip("\r\n")
                list_character.append(character)
            for eno in range(1 + character_num, 1 + character_num + elem_num):
                elem = lines[eno].decode("utf-8").strip("\n").strip("\r\n")
                list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        list_character = [self.beg_str] + list_character + [self.end_str]
        return list_character

    def __call__(self, preds):
        structure_probs = preds["structure_probs"]
        loc_preds = preds["loc_preds"]
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(loc_preds, paddle.Tensor):
            loc_preds = loc_preds.numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        (
            structure_str,
            structure_pos,
            result_score_list,
            result_elem_idx_list,
        ) = self.decode(structure_idx, structure_probs, "elem")
        res_html_code_list = []
        res_loc_list = []
        batch_num = len(structure_str)
        for bno in range(batch_num):
            res_loc = []
            for sno in range(len(structure_str[bno])):
                text = structure_str[bno][sno]
                if text in ["<td>", "<td"]:
                    pos = structure_pos[bno][sno]
                    res_loc.append(loc_preds[bno, pos])
            res_html_code = "".join(structure_str[bno])
            res_loc = np.array(res_loc)
            res_html_code_list.append(res_html_code)
            res_loc_list.append(res_loc)
        return {
            "res_html_code": res_html_code_list,
            "res_loc": res_loc_list,
            "res_score_list": result_score_list,
            "res_elem_idx_list": result_elem_idx_list,
            "structure_str_list": structure_str,
        }

    def decode(self, text_index, structure_probs, char_or_elem):
        """convert text-label into text-index."""
        if char_or_elem == "char":
            current_dict = self.dict_idx_character
        else:
            current_dict = self.dict_idx_elem
            ignored_tokens = self.get_ignored_tokens("elem")
            beg_idx, end_idx = ignored_tokens

        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue

                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return result_list, result_pos_list, result_score_list, result_elem_idx_list

    def get_ignored_tokens(self, char_or_elem):
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        if char_or_elem == "char":
            if beg_or_end == "beg":
                idx = self.dict_character[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_character[self.end_str]
            else:
                assert False, (
                    "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
                )
        elif char_or_elem == "elem":
            if beg_or_end == "beg":
                idx = self.dict_elem[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_elem[self.end_str]
            else:
                assert False, (
                    "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
                )
        else:
            assert False, "Unsupport type %s in char_or_elem" % char_or_elem
        return idx


class SARLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)

        self.rm_symbol = kwargs.get("rm_symbol", False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            if self.rm_symbol:
                comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
                text = text.lower()
                text = comp.sub("", text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]


class DistillationSARLabelDecode(SARLabelDecode):
    """
    Convert
    Convert between text-label and text-index
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationSARLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        if not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred["sar"]
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output


class PRENLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        padding_str = "<PAD>"  # 0
        end_str = "<EOS>"  # 1
        unknown_str = "<UNK>"  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character
| 695 | 
            +
             | 
| 696 | 
            +
                def decode(self, text_index, text_prob=None):
         | 
| 697 | 
            +
                    """convert text-index into text-label."""
         | 
| 698 | 
            +
                    result_list = []
         | 
| 699 | 
            +
                    batch_size = len(text_index)
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                    for batch_idx in range(batch_size):
         | 
| 702 | 
            +
                        char_list = []
         | 
| 703 | 
            +
                        conf_list = []
         | 
| 704 | 
            +
                        for idx in range(len(text_index[batch_idx])):
         | 
| 705 | 
            +
                            if text_index[batch_idx][idx] == self.end_idx:
         | 
| 706 | 
            +
                                break
         | 
| 707 | 
            +
                            if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
         | 
| 708 | 
            +
                                continue
         | 
| 709 | 
            +
                            char_list.append(self.character[int(text_index[batch_idx][idx])])
         | 
| 710 | 
            +
                            if text_prob is not None:
         | 
| 711 | 
            +
                                conf_list.append(text_prob[batch_idx][idx])
         | 
| 712 | 
            +
                            else:
         | 
| 713 | 
            +
                                conf_list.append(1)
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                        text = "".join(char_list)
         | 
| 716 | 
            +
                        if len(text) > 0:
         | 
| 717 | 
            +
                            result_list.append((text, np.mean(conf_list).tolist()))
         | 
| 718 | 
            +
                        else:
         | 
| 719 | 
            +
                            # here confidence of empty recog result is 1
         | 
| 720 | 
            +
                            result_list.append(("", 1))
         | 
| 721 | 
            +
                    return result_list
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                def __call__(self, preds, label=None, *args, **kwargs):
         | 
| 724 | 
            +
                    preds = preds.numpy()
         | 
| 725 | 
            +
                    preds_idx = preds.argmax(axis=2)
         | 
| 726 | 
            +
                    preds_prob = preds.max(axis=2)
         | 
| 727 | 
            +
                    text = self.decode(preds_idx, preds_prob)
         | 
| 728 | 
            +
                    if label is None:
         | 
| 729 | 
            +
                        return text
         | 
| 730 | 
            +
                    label = self.decode(label)
         | 
| 731 | 
            +
                    return text, label
         | 
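All of the decoders above share the same argmax-then-lookup pattern: take the per-step class scores, pick the most likely index, and map indices back to characters while skipping special tokens. A minimal standalone sketch of that idea, mirroring PRENLabelDecode's token layout; the character table and scores are invented for illustration and are not part of this diff:

    import numpy as np

    # Toy character table mirroring PRENLabelDecode: <PAD>=0, <EOS>=1, <UNK>=2.
    characters = ["<PAD>", "<EOS>", "<UNK>", "h", "i"]

    # Fake network output with shape (batch=1, seq_len=4, num_classes=5).
    preds = np.array([[[0.1, 0.0, 0.0, 0.8, 0.1],
                       [0.0, 0.1, 0.1, 0.1, 0.7],
                       [0.1, 0.8, 0.1, 0.0, 0.0],
                       [0.9, 0.0, 0.1, 0.0, 0.0]]])

    preds_idx = preds.argmax(axis=2)   # most likely class per step
    preds_prob = preds.max(axis=2)     # its confidence

    chars, confs = [], []
    for idx, prob in zip(preds_idx[0], preds_prob[0]):
        if idx == 1:          # <EOS> ends the sequence
            break
        if idx in (0, 2):     # skip <PAD> / <UNK>
            continue
        chars.append(characters[idx])
        confs.append(prob)

    print("".join(chars), float(np.mean(confs)))  # -> "hi" 0.75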
    	
ocr/postprocess/sast_postprocess.py
ADDED
@@ -0,0 +1,355 @@
from __future__ import absolute_import, division, print_function

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, ".."))

import time

import cv2
import numpy as np
import paddle

from .locality_aware_nms import nms_locality


class SASTPostProcess(object):
    """
    The post process for SAST.
    """

    def __init__(
        self,
        score_thresh=0.5,
        nms_thresh=0.2,
        sample_pts_num=2,
        shrink_ratio_of_width=0.3,
        expand_scale=1.0,
        tcl_map_thresh=0.5,
        **kwargs
    ):

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sample_pts_num = sample_pts_num
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # c++ la-nms is faster, but only supports Python 3.5
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def point_pair2poly(self, point_pair_list):
        """
        Transfer vertical point_pairs into poly point in clockwise.
        """
        # construct poly
        point_num = len(point_pair_list) * 2
        point_list = [0] * point_num
        for idx, point_pair in enumerate(point_pair_list):
            point_list[idx] = point_pair[0]
            point_list[point_num - 1 - idx] = point_pair[1]
        return np.array(point_list).reshape(-1, 2)

    def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
        """
        Generate shrink_quad_along_width.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
        )
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        expand poly along width.
        """
        point_num = poly.shape[0]
        left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        left_ratio = (
            -shrink_ratio_of_width
            * np.linalg.norm(left_quad[0] - left_quad[3])
            / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        )
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
        right_quad = np.array(
            [
                poly[point_num // 2 - 2],
                poly[point_num // 2 - 1],
                poly[point_num // 2],
                poly[point_num // 2 + 1],
            ],
            dtype=np.float32,
        )
        right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
            right_quad[0] - right_quad[3]
        ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[point_num // 2 - 1] = right_quad_expand[1]
        poly[point_num // 2] = right_quad_expand[2]
        return poly

    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """Restore quad."""
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)

        # Sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 1])]

        scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
        scores = scores[:, np.newaxis]

        # Restore
        point_num = int(tvo_map.shape[-1] / 2)
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text

    def quad_area(self, quad):
        """
        compute area of a quad.
        """
        edge = [
            (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
        ]
        return np.sum(edge) / 2.0

    def nms(self, dets):
        if self.is_python35:
            import lanms

            dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
        else:
            dets = nms_locality(dets, self.nms_thresh)
        return dets

    def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
        """
        Cluster pixels in tcl_map based on quads.
        """
        instance_count = quads.shape[0] + 1  # contain background
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map

        # predict text center
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        n = xy_text.shape[0]
        xy_text = xy_text[:, ::-1]  # (n, 2)
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
        pred_tc = xy_text - tco

        # get gt text center
        m = quads.shape[0]
        gt_tc = np.mean(quads, axis=1)  # (m, 2)

        pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1))  # (n, m, 2)
        gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2)
        dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map

    def estimate_sample_pts_num(self, quad, xy_text):
        """
        Estimate sample points number.
        """
        eh = (
            np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
        ) / 2.0
        ew = (
            np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
        ) / 2.0

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[
            np.linspace(
                0,
                xy_text.shape[0] - 1,
                dense_sample_pts_num,
                endpoint=True,
                dtype=np.float32,
            ).astype(np.int32)
        ]

        dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))

        sample_pts_num = max(2, int(estimate_arc_len / eh))
        return sample_pts_num

    def detect_sast(
        self,
        tcl_map,
        tvo_map,
        tbo_map,
        tco_map,
        ratio_w,
        ratio_h,
        src_w,
        src_h,
        shrink_ratio_of_width=0.3,
        tcl_map_thresh=0.5,
        offset_expand=1.0,
        out_strid=4.0,
    ):
        """
        first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
        """
        # restore quad
        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        dets = self.nms(dets)
        if dets.shape[0] == 0:
            return []
        quads = dets[:, :-1].reshape(-1, 4, 2)

        # Compute quad area
        quad_areas = []
        for quad in quads:
            quad_areas.append(-self.quad_area(quad))

        # instance segmentation
        # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
        instance_count, instance_label_map = self.cluster_by_quads_tco(
            tcl_map, tcl_map_thresh, quads, tco_map
        )

        # restore single poly with tcl instance.
        poly_list = []
        for instance_idx in range(1, instance_count):
            xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
            quad = quads[instance_idx - 1]
            q_area = quad_areas[instance_idx - 1]
            if q_area < 5:
                continue

            #
            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            min_len = min(len1, len2)
            if min_len < 3:
                continue

            # filter small CC
            if xy_text.shape[0] <= 0:
                continue

            # filter low confidence instance
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
                # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
                continue

            # sort xy_text
            left_center_pt = np.array(
                [[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]
            )  # (1, 2)
            right_center_pt = np.array(
                [[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]]
            )  # (1, 2)
            proj_unit_vec = (right_center_pt - left_center_pt) / (
                np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
            )
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
            xy_text = xy_text[np.argsort(proj_value)]

            # Sample pts in tcl map
            if self.sample_pts_num == 0:
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[
                np.linspace(
                    0,
                    xy_text.shape[0] - 1,
                    sample_pts_num,
                    endpoint=True,
                    dtype=np.float32,
                ).astype(np.int32)
            ]

            point_pair_list = []
            for x, y in xy_center_line:
                # get corresponding offset
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
                    )
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                # original point
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (
                    (ori_yx + offset)[:, ::-1]
                    * out_strid
                    / np.array([ratio_w, ratio_h]).reshape(-1, 2)
                )
                point_pair_list.append(point_pair)

            # ndarray: (x, 2), expand poly along width
            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(
                detected_poly, shrink_ratio_of_width
            )
            detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    def __call__(self, outs_dict, shape_list):
        score_list = outs_dict["f_score"]
        border_list = outs_dict["f_border"]
        tvo_list = outs_dict["f_tvo"]
        tco_list = outs_dict["f_tco"]
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            border_list = border_list.numpy()
            tvo_list = tvo_list.numpy()
            tco_list = tco_list.numpy()

        img_num = len(shape_list)
        poly_lists = []
        for ino in range(img_num):
            p_score = score_list[ino].transpose((1, 2, 0))
            p_border = border_list[ino].transpose((1, 2, 0))
            p_tvo = tvo_list[ino].transpose((1, 2, 0))
            p_tco = tco_list[ino].transpose((1, 2, 0))
            src_h, src_w, ratio_h, ratio_w = shape_list[ino]

            poly_list = self.detect_sast(
                p_score,
                p_tvo,
                p_border,
                p_tco,
                ratio_w,
                ratio_h,
                src_w,
                src_h,
                shrink_ratio_of_width=self.shrink_ratio_of_width,
                tcl_map_thresh=self.tcl_map_thresh,
                offset_expand=self.expand_scale,
            )
            poly_lists.append({"points": np.array(poly_list)})

        return poly_lists
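detect_sast filters candidate boxes by the signed quad area from the shoelace formula: with SAST's clockwise corner order in image coordinates, quad_area comes out negative, which is why the caller stores -quad_area(quad). A small self-contained check of that convention, with toy coordinates invented for illustration:

    import numpy as np

    def quad_area(quad):
        # Signed area via the shoelace formula, copied from SASTPostProcess.quad_area.
        edge = [
            (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
        ]
        return np.sum(edge) / 2.0

    # Unit square ordered top-left, top-right, bottom-right, bottom-left (image coordinates).
    quad = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
    print(quad_area(quad))  # -1.0 signed, so detect_sast negates it to get the positive area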
    	
ocr/postprocess/vqa_token_re_layoutlm_postprocess.py
ADDED
@@ -0,0 +1,36 @@
class VQAReTokenLayoutLMPostProcess(object):
    """Convert between text-label and text-index"""

    def __init__(self, **kwargs):
        super(VQAReTokenLayoutLMPostProcess, self).__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        if label is not None:
            return self._metric(preds, label)
        else:
            return self._infer(preds, *args, **kwargs)

    def _metric(self, preds, label):
        return preds["pred_relations"], label[6], label[5]

    def _infer(self, preds, *args, **kwargs):
        ser_results = kwargs["ser_results"]
        entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]
        pred_relations = preds["pred_relations"]

        # merge relations and ocr info
        results = []
        for pred_relation, ser_result, entity_idx_dict in zip(
            pred_relations, ser_results, entity_idx_dict_batch
        ):
            result = []
            used_tail_id = []
            for relation in pred_relation:
                if relation["tail_id"] in used_tail_id:
                    continue
                used_tail_id.append(relation["tail_id"])
                ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
                ocr_info_tail = ser_result[entity_idx_dict[relation["tail_id"]]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results
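_infer only indexes into the structures it is handed, so a toy call is enough to show the expected shapes. Everything below is invented for illustration (the class never inspects the contents of the SER dictionaries) and assumes the class above is importable:

    # One image, one predicted relation linking entity 0 (head) to entity 1 (tail).
    preds = {"pred_relations": [[{"head_id": 0, "tail_id": 1}]]}
    ser_results = [[{"text": "Name:"}, {"text": "Alice"}]]   # per-image SER entities
    entity_idx_dict_batch = [{0: 0, 1: 1}]                   # entity id -> index into ser_results[i]

    post = VQAReTokenLayoutLMPostProcess()
    linked = post(preds, ser_results=ser_results,
                  entity_idx_dict_batch=entity_idx_dict_batch)
    # linked[0] == [({"text": "Name:"}, {"text": "Alice"})]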
    	
ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
ADDED
@@ -0,0 +1,96 @@
import numpy as np
import paddle


def load_vqa_bio_label_maps(label_map_path):
    with open(label_map_path, "r", encoding="utf-8") as fin:
        lines = fin.readlines()
    lines = [line.strip() for line in lines]
    if "O" not in lines:
        lines.insert(0, "O")
    labels = []
    for line in lines:
        if line == "O":
            labels.append("O")
        else:
            labels.append("B-" + line)
            labels.append("I-" + line)
    label2id_map = {label: idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label for idx, label in enumerate(labels)}
    return label2id_map, id2label_map


class VQASerTokenLayoutLMPostProcess(object):
    """Convert between text-label and text-index"""

    def __init__(self, class_path, **kwargs):
        super(VQASerTokenLayoutLMPostProcess, self).__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

        self.id2label_map_for_show = dict()
        for key in self.label2id_map_for_draw:
            val = self.label2id_map_for_draw[key]
            if key == "O":
                self.id2label_map_for_show[val] = key
            if key.startswith("B-") or key.startswith("I-"):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[1])
        else:
            return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        pred_idxs = preds.argmax(axis=2)
        decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
        label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]

        for i in range(pred_idxs.shape[0]):
            for j in range(pred_idxs.shape[1]):
                if label[i, j] != -100:
                    label_decode_out_list[i].append(self.id2label_map[label[i, j]])
                    decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
        results = []

        for pred, attention_mask, segment_offset_id, ocr_info in zip(
            preds, attention_masks, segment_offset_ids, ocr_infos
        ):
            pred = np.argmax(pred, axis=1)
            pred = [self.id2label_map[idx] for idx in pred]

            for idx in range(len(segment_offset_id)):
                if idx == 0:
                    start_id = 0
                else:
                    start_id = segment_offset_id[idx - 1]

                end_id = segment_offset_id[idx]

                curr_pred = pred[start_id:end_id]
                curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]

                if len(curr_pred) <= 0:
                    pred_id = 0
                else:
                    counts = np.bincount(curr_pred)
                    pred_id = np.argmax(counts)
                ocr_info[idx]["pred_id"] = int(pred_id)
                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
            results.append(ocr_info)
        return results
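load_vqa_bio_label_maps expands each entity type into B-/I- tags and always reserves id 0 for "O". A quick sketch against a hypothetical two-class file (the class names are invented for illustration):

    import tempfile

    # Hypothetical class file listing one entity type per line.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("QUESTION\nANSWER\n")
        class_path = f.name

    label2id, id2label = load_vqa_bio_label_maps(class_path)
    print(label2id)
    # {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4}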
    	
ocr/ppocr/__init__.py
ADDED
File without changes
    	
ocr/ppocr/data/__init__.py
ADDED
@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import signal
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))

import copy

from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

from .imaug import create_operators, transform

__all__ = ["build_dataloader", "transform", "create_operators"]


def term_mp(sig_num, frame):
    """kill all child processes"""
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    print("main proc {} exit, kill process group " "{}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = ["SimpleDataSet", "LMDBDataSet", "PGDataSet", "PubTabDataSet"]
    module_name = config[mode]["dataset"]["name"]
    assert module_name in support_dict, Exception(
        "DataSet only support {}".format(support_dict)
    )
    assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."

    dataset = eval(module_name)(config, mode, logger, seed)
    loader_config = config[mode]["loader"]
    batch_size = loader_config["batch_size_per_card"]
    drop_last = loader_config["drop_last"]
    shuffle = loader_config["shuffle"]
    num_workers = loader_config["num_workers"]
    if "use_shared_memory" in loader_config.keys():
        use_shared_memory = loader_config["use_shared_memory"]
    else:
        use_shared_memory = True

    if mode == "Train":
        # Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
        )
    else:
        # Distribute data to single card
        batch_sampler = BatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
        )

    if "collate_fn" in loader_config:
        from . import collate_fn

        collate_fn = getattr(collate_fn, loader_config["collate_fn"])()
    else:
        collate_fn = None
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn,
    )

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
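build_dataloader only reads a handful of keys from the config it is handed. A minimal sketch of that structure with illustrative values; everything dataset-specific (paths, transforms, and so on) is omitted here and would be consumed by the dataset class itself:

    # Only the keys build_dataloader actually reads are shown.
    config = {
        "Eval": {
            "dataset": {"name": "SimpleDataSet"},   # must be one of support_dict
            "loader": {
                "batch_size_per_card": 16,
                "drop_last": False,
                "shuffle": False,
                "num_workers": 4,
                # "use_shared_memory": True,        # optional, defaults to True
                # "collate_fn": "DictCollator",     # optional, looked up in collate_fn.py
            },
        }
    }
    # loader = build_dataloader(config, "Eval", device, logger)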
    	
ocr/ppocr/data/collate_fn.py
ADDED
@@ -0,0 +1,59 @@
import numbers
from collections import defaultdict

import numpy as np
import paddle


class DictCollator(object):
    """
    Collate a batch of sample dicts; numeric fields are stacked into paddle tensors.
    """

    def __call__(self, batch):
        # todo: support batch operators
        data_dict = defaultdict(list)
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if k not in to_tensor_keys:
                        to_tensor_keys.append(k)
                data_dict[k].append(v)
        for k in to_tensor_keys:
            data_dict[k] = paddle.to_tensor(data_dict[k])
        return data_dict


class ListCollator(object):
    """
    Collate a batch of sample sequences; numeric fields are stacked into paddle tensors.
    """

    def __call__(self, batch):
        # todo: support batch operators
        data_dict = defaultdict(list)
        to_tensor_idxs = []
        for sample in batch:
            for idx, v in enumerate(sample):
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if idx not in to_tensor_idxs:
                        to_tensor_idxs.append(idx)
                data_dict[idx].append(v)
        for idx in to_tensor_idxs:
            data_dict[idx] = paddle.to_tensor(data_dict[idx])
        return list(data_dict.values())


class SSLRotateCollate(object):
    """
    batch: [
        [(4*3xH*W), (4,)]
        [(4*3xH*W), (4,)]
        ...
    ]
    """

    def __call__(self, batch):
        output = [np.concatenate(d, axis=0) for d in zip(*batch)]
        return output
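A quick sanity check for the collators (not part of the diff), assuming the classes above are in scope: numeric fields are stacked into paddle tensors while string fields stay as plain Python lists.

import numpy as np

samples = [
    {"image": np.zeros((3, 32, 100), dtype="float32"), "label": "hello"},
    {"image": np.ones((3, 32, 100), dtype="float32"), "label": "world"},
]
batch = DictCollator()(samples)
print(batch["image"].shape)  # [2, 3, 32, 100] as a paddle.Tensor
print(batch["label"])        # ['hello', 'world'] left untouched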
    	
        ocr/ppocr/data/imaug/ColorJitter.py
    ADDED
    
@@ -0,0 +1,14 @@
from paddle.vision.transforms import ColorJitter as pp_ColorJitter

__all__ = ["ColorJitter"]


class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        image = data["image"]
        image = self.aug(image)
        data["image"] = image
        return data
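Usage sketch (not part of the diff): the op wraps paddle.vision.transforms.ColorJitter and plugs into the sample-dict pipeline. This assumes the wrapped transform accepts an HWC uint8 numpy image, which paddle's vision transforms generally do.

import numpy as np

op = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
data = {"image": np.random.randint(0, 255, (48, 320, 3)).astype("uint8")}
out = op(data)
print(out["image"].shape)  # (48, 320, 3), colour-jittered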
    	
        ocr/ppocr/data/imaug/__init__.py
    ADDED
    
@@ -0,0 +1,61 @@
from __future__ import absolute_import, division, print_function, unicode_literals

from .ColorJitter import ColorJitter
from .copy_paste import CopyPaste
from .east_process import *
from .fce_aug import *
from .fce_targets import FCENetTargets
from .gen_table_mask import *
from .iaa_augment import IaaAugment
from .label_ops import *
from .make_border_map import MakeBorderMap
from .make_pse_gt import MakePseGt
from .make_shrink_map import MakeShrinkMap
from .operators import *
from .pg_process import *
from .randaugment import RandAugment
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .rec_img_aug import (
    ClsResizeImg,
    NRTRRecResizeImg,
    PRENResizeImg,
    RecAug,
    RecConAug,
    RecResizeImg,
    SARRecResizeImg,
    SRNRecResizeImg,
)
from .sast_process import *
from .ssl_img_aug import SSLRotateResize
from .vqa import *


def transform(data, ops=None):
    """Apply each op in `ops` to `data` in order; return None if any op rejects the sample."""
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config

    Args:
        op_param_list(list): a dict list, used to create some operators
    """
    assert isinstance(op_param_list, list), "operator config should be a list"
    ops = []
    for operator in op_param_list:
        assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
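A minimal pipeline sketch (not part of the diff), assuming the DecodeImage and KeepKeys operators exported from operators.py keep their usual PaddleOCR signatures; "sample.jpg" is a placeholder path.

ops = create_operators(
    [
        {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
        {"KeepKeys": {"keep_keys": ["image"]}},
    ]
)
with open("sample.jpg", "rb") as f:  # placeholder image path
    data = {"image": f.read()}
(image,) = transform(data, ops)  # decoded HWC uint8 array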
    	
        ocr/ppocr/data/imaug/copy_paste.py
    ADDED
    
@@ -0,0 +1,167 @@
import os
import random

import cv2
import numpy as np
from PIL import Image
from shapely.geometry import Polygon

from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from utility import get_rotate_crop_image


class CopyPaste(object):
    def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
        self.ext_data_num = 1
        self.objects_paste_ratio = objects_paste_ratio
        self.limit_paste = limit_paste
        augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}]
        self.aug = IaaAugment(augmenter_args)

    def __call__(self, data):
        point_num = data["polys"].shape[1]
        src_img = data["image"]
        src_polys = data["polys"].tolist()
        src_texts = data["texts"]
        src_ignores = data["ignore_tags"].tolist()
        ext_data = data["ext_data"][0]
        ext_image = ext_data["image"]
        ext_polys = ext_data["polys"]
        ext_texts = ext_data["texts"]
        ext_ignores = ext_data["ignore_tags"]

        indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
        select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))

        random.shuffle(indexs)
        select_idxs = indexs[:select_num]
        select_polys = ext_polys[select_idxs]
        select_ignores = ext_ignores[select_idxs]

        src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
        src_img = Image.fromarray(src_img).convert("RGBA")
        for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
            box_img = get_rotate_crop_image(ext_image, poly)

            src_img, box = self.paste_img(src_img, box_img, src_polys)
            if box is not None:
                box = box.tolist()
                for _ in range(len(box), point_num):
                    box.append(box[-1])
                src_polys.append(box)
                src_texts.append(ext_texts[idx])
                src_ignores.append(tag)
        src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
        h, w = src_img.shape[:2]
        src_polys = np.array(src_polys)
        src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
        src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
        data["image"] = src_img
        data["polys"] = src_polys
        data["texts"] = src_texts
        data["ignore_tags"] = np.array(src_ignores)
        return data

    def paste_img(self, src_img, box_img, src_polys):
        box_img_pil = Image.fromarray(box_img).convert("RGBA")
        src_w, src_h = src_img.size
        box_w, box_h = box_img_pil.size

        angle = np.random.randint(0, 360)
        box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
        box = rotate_bbox(box_img, box, angle)[0]
        box_img_pil = box_img_pil.rotate(angle, expand=1)
        box_w, box_h = box_img_pil.width, box_img_pil.height
        if src_w - box_w < 0 or src_h - box_h < 0:
            return src_img, None

        paste_x, paste_y = self.select_coord(
            src_polys, box, src_w - box_w, src_h - box_h
        )
        if paste_x is None:
            return src_img, None
        box[:, 0] += paste_x
        box[:, 1] += paste_y
        r, g, b, A = box_img_pil.split()
        src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)

        return src_img, box

    def select_coord(self, src_polys, box, endx, endy):
        if self.limit_paste:
            xmin, ymin, xmax, ymax = (
                box[:, 0].min(),
                box[:, 1].min(),
                box[:, 0].max(),
                box[:, 1].max(),
            )
            for _ in range(50):
                paste_x = random.randint(0, endx)
                paste_y = random.randint(0, endy)
                xmin1 = xmin + paste_x
                xmax1 = xmax + paste_x
                ymin1 = ymin + paste_y
                ymax1 = ymax + paste_y

                num_poly_in_rect = 0
                for poly in src_polys:
                    if not is_poly_outside_rect(
                        poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1
                    ):
                        num_poly_in_rect += 1
                        break
                if num_poly_in_rect == 0:
                    return paste_x, paste_y
            return None, None
        else:
            paste_x = random.randint(0, endx)
            paste_y = random.randint(0, endy)
            return paste_x, paste_y


def get_union(pD, pG):
    return Polygon(pD).union(Polygon(pG)).area


def get_intersection_over_union(pD, pG):
    return get_intersection(pD, pG) / get_union(pD, pG)


def get_intersection(pD, pG):
    return Polygon(pD).intersection(Polygon(pG)).area


def rotate_bbox(img, text_polys, angle, scale=1):
    """
    from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
    Args:
        img: np.ndarray
        text_polys: np.ndarray N*4*2
        angle: int
        scale: int

    Returns:
        the rotated polygons, np.ndarray N*4*2
    """
    w = img.shape[1]
    h = img.shape[0]

    rangle = np.deg2rad(angle)
    nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
    nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
    rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
    rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]

    # ---------------------- rotate box ----------------------
    rot_text_polys = list()
    for bbox in text_polys:
        point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
        point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
        point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
        point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
        rot_text_polys.append([point1, point2, point3, point4])
    return np.array(rot_text_polys, dtype=np.float32)
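The polygon helpers at the bottom of the file can be exercised on their own (not part of the diff): two 2x2 squares overlapping by half give an intersection of 2, a union of 6, and an IoU of 1/3.

import numpy as np

a = np.array([[0, 0], [2, 0], [2, 2], [0, 2]], dtype=np.float32)
b = np.array([[1, 0], [3, 0], [3, 2], [1, 2]], dtype=np.float32)
print(get_intersection(a, b))             # 2.0
print(get_union(a, b))                    # 6.0
print(get_intersection_over_union(a, b))  # 0.333...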
    	
        ocr/ppocr/data/imaug/east_process.py
    ADDED
    
@@ -0,0 +1,427 @@
| 1 | 
            +
            import math
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            __all__ = ["EASTProcessTrain"]
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class EASTProcessTrain(object):
         | 
| 10 | 
            +
                def __init__(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    image_shape=[512, 512],
         | 
| 13 | 
            +
                    background_ratio=0.125,
         | 
| 14 | 
            +
                    min_crop_side_ratio=0.1,
         | 
| 15 | 
            +
                    min_text_size=10,
         | 
| 16 | 
            +
                    **kwargs
         | 
| 17 | 
            +
                ):
         | 
| 18 | 
            +
                    self.input_size = image_shape[1]
         | 
| 19 | 
            +
                    self.random_scale = np.array([0.5, 1, 2.0, 3.0])
         | 
| 20 | 
            +
                    self.background_ratio = background_ratio
         | 
| 21 | 
            +
                    self.min_crop_side_ratio = min_crop_side_ratio
         | 
| 22 | 
            +
                    self.min_text_size = min_text_size
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def preprocess(self, im):
         | 
| 25 | 
            +
                    input_size = self.input_size
         | 
| 26 | 
            +
                    im_shape = im.shape
         | 
| 27 | 
            +
                    im_size_min = np.min(im_shape[0:2])
         | 
| 28 | 
            +
                    im_size_max = np.max(im_shape[0:2])
         | 
| 29 | 
            +
                    im_scale = float(input_size) / float(im_size_max)
         | 
| 30 | 
            +
                    im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
         | 
| 31 | 
            +
                    img_mean = [0.485, 0.456, 0.406]
         | 
| 32 | 
            +
                    img_std = [0.229, 0.224, 0.225]
         | 
| 33 | 
            +
                    # im = im[:, :, ::-1].astype(np.float32)
         | 
| 34 | 
            +
                    im = im / 255
         | 
| 35 | 
            +
                    im -= img_mean
         | 
| 36 | 
            +
                    im /= img_std
         | 
| 37 | 
            +
                    new_h, new_w, _ = im.shape
         | 
| 38 | 
            +
                    im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
         | 
| 39 | 
            +
                    im_padded[:new_h, :new_w, :] = im
         | 
| 40 | 
            +
                    im_padded = im_padded.transpose((2, 0, 1))
         | 
| 41 | 
            +
                    im_padded = im_padded[np.newaxis, :]
         | 
| 42 | 
            +
                    return im_padded, im_scale
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def rotate_im_poly(self, im, text_polys):
         | 
| 45 | 
            +
                    """
         | 
| 46 | 
            +
                    rotate image with 90 / 180 / 270 degre
         | 
| 47 | 
            +
                    """
         | 
| 48 | 
            +
                    im_w, im_h = im.shape[1], im.shape[0]
         | 
| 49 | 
            +
                    dst_im = im.copy()
         | 
| 50 | 
            +
                    dst_polys = []
         | 
| 51 | 
            +
                    rand_degree_ratio = np.random.rand()
         | 
| 52 | 
            +
                    rand_degree_cnt = 1
         | 
| 53 | 
            +
                    if 0.333 < rand_degree_ratio < 0.666:
         | 
| 54 | 
            +
                        rand_degree_cnt = 2
         | 
| 55 | 
            +
                    elif rand_degree_ratio > 0.666:
         | 
| 56 | 
            +
                        rand_degree_cnt = 3
         | 
| 57 | 
            +
                    for i in range(rand_degree_cnt):
         | 
| 58 | 
            +
                        dst_im = np.rot90(dst_im)
         | 
| 59 | 
            +
                    rot_degree = -90 * rand_degree_cnt
         | 
| 60 | 
            +
                    rot_angle = rot_degree * math.pi / 180.0
         | 
| 61 | 
            +
                    n_poly = text_polys.shape[0]
         | 
| 62 | 
            +
                    cx, cy = 0.5 * im_w, 0.5 * im_h
         | 
| 63 | 
            +
                    ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
         | 
| 64 | 
            +
                    for i in range(n_poly):
         | 
| 65 | 
            +
                        wordBB = text_polys[i]
         | 
| 66 | 
            +
                        poly = []
         | 
| 67 | 
            +
                        for j in range(4):
         | 
| 68 | 
            +
                            sx, sy = wordBB[j][0], wordBB[j][1]
         | 
| 69 | 
            +
                            dx = (
         | 
| 70 | 
            +
                                math.cos(rot_angle) * (sx - cx)
         | 
| 71 | 
            +
                                - math.sin(rot_angle) * (sy - cy)
         | 
| 72 | 
            +
                                + ncx
         | 
| 73 | 
            +
                            )
         | 
| 74 | 
            +
                            dy = (
         | 
| 75 | 
            +
                                math.sin(rot_angle) * (sx - cx)
         | 
| 76 | 
            +
                                + math.cos(rot_angle) * (sy - cy)
         | 
| 77 | 
            +
                                + ncy
         | 
| 78 | 
            +
                            )
         | 
| 79 | 
            +
                            poly.append([dx, dy])
         | 
| 80 | 
            +
                        dst_polys.append(poly)
         | 
| 81 | 
            +
                    dst_polys = np.array(dst_polys, dtype=np.float32)
         | 
| 82 | 
            +
                    return dst_im, dst_polys
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def polygon_area(self, poly):
         | 
| 85 | 
            +
                    """
         | 
| 86 | 
            +
                    compute area of a polygon
         | 
| 87 | 
            +
                    :param poly:
         | 
| 88 | 
            +
                    :return:
         | 
| 89 | 
            +
                    """
         | 
| 90 | 
            +
                    edge = [
         | 
| 91 | 
            +
                        (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
         | 
| 92 | 
            +
                        (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
         | 
| 93 | 
            +
                        (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
         | 
| 94 | 
            +
                        (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
         | 
| 95 | 
            +
                    ]
         | 
| 96 | 
            +
                    return np.sum(edge) / 2.0
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def check_and_validate_polys(self, polys, tags, img_height, img_width):
         | 
| 99 | 
            +
                    """
         | 
| 100 | 
            +
                    check so that the text poly is in the same direction,
         | 
| 101 | 
            +
                    and also filter some invalid polygons
         | 
| 102 | 
            +
                    :param polys:
         | 
| 103 | 
            +
                    :param tags:
         | 
| 104 | 
            +
                    :return:
         | 
| 105 | 
            +
                    """
         | 
| 106 | 
            +
                    h, w = img_height, img_width
         | 
| 107 | 
            +
                    if polys.shape[0] == 0:
         | 
| 108 | 
            +
                        return polys
         | 
| 109 | 
            +
                    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
         | 
| 110 | 
            +
                    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    validated_polys = []
         | 
| 113 | 
            +
                    validated_tags = []
         | 
| 114 | 
            +
                    for poly, tag in zip(polys, tags):
         | 
| 115 | 
            +
                        p_area = self.polygon_area(poly)
         | 
| 116 | 
            +
                        # invalid poly
         | 
| 117 | 
            +
                        if abs(p_area) < 1:
         | 
| 118 | 
            +
                            continue
         | 
| 119 | 
            +
                        if p_area > 0:
         | 
| 120 | 
            +
                            #'poly in wrong direction'
         | 
| 121 | 
            +
                            if not tag:
         | 
| 122 | 
            +
                                tag = True  # reversed cases should be ignore
         | 
| 123 | 
            +
                            poly = poly[(0, 3, 2, 1), :]
         | 
| 124 | 
            +
                        validated_polys.append(poly)
         | 
| 125 | 
            +
                        validated_tags.append(tag)
         | 
| 126 | 
            +
                    return np.array(validated_polys), np.array(validated_tags)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def draw_img_polys(self, img, polys):
         | 
| 129 | 
            +
                    if len(img.shape) == 4:
         | 
| 130 | 
            +
                        img = np.squeeze(img, axis=0)
         | 
| 131 | 
            +
                    if img.shape[0] == 3:
         | 
| 132 | 
            +
                        img = img.transpose((1, 2, 0))
         | 
| 133 | 
            +
                        img[:, :, 2] += 123.68
         | 
| 134 | 
            +
                        img[:, :, 1] += 116.78
         | 
| 135 | 
            +
                        img[:, :, 0] += 103.94
         | 
| 136 | 
            +
                    cv2.imwrite("tmp.jpg", img)
         | 
| 137 | 
            +
                    img = cv2.imread("tmp.jpg")
         | 
| 138 | 
            +
                    for box in polys:
         | 
| 139 | 
            +
                        box = box.astype(np.int32).reshape((-1, 1, 2))
         | 
| 140 | 
            +
                        cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
         | 
| 141 | 
            +
                    import random
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    ino = random.randint(0, 100)
         | 
| 144 | 
            +
                    cv2.imwrite("tmp_%d.jpg" % ino, img)
         | 
| 145 | 
            +
                    return
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def shrink_poly(self, poly, r):
         | 
| 148 | 
            +
                    """
         | 
| 149 | 
            +
                    fit a poly inside the origin poly, maybe bugs here...
         | 
| 150 | 
            +
                    used for generate the score map
         | 
| 151 | 
            +
                    :param poly: the text poly
         | 
| 152 | 
            +
                    :param r: r in the paper
         | 
| 153 | 
            +
                    :return: the shrinked poly
         | 
| 154 | 
            +
                    """
         | 
| 155 | 
            +
                    # shrink ratio
         | 
| 156 | 
            +
                    R = 0.3
         | 
| 157 | 
            +
                    # find the longer pair
         | 
| 158 | 
            +
                    dist0 = np.linalg.norm(poly[0] - poly[1])
         | 
| 159 | 
            +
                    dist1 = np.linalg.norm(poly[2] - poly[3])
         | 
| 160 | 
            +
                    dist2 = np.linalg.norm(poly[0] - poly[3])
         | 
| 161 | 
            +
                    dist3 = np.linalg.norm(poly[1] - poly[2])
         | 
| 162 | 
            +
                    if dist0 + dist1 > dist2 + dist3:
         | 
| 163 | 
            +
                        # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
         | 
| 164 | 
            +
                        ## p0, p1
         | 
| 165 | 
            +
                        theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
         | 
| 166 | 
            +
                        poly[0][0] += R * r[0] * np.cos(theta)
         | 
| 167 | 
            +
                        poly[0][1] += R * r[0] * np.sin(theta)
         | 
| 168 | 
            +
                        poly[1][0] -= R * r[1] * np.cos(theta)
         | 
| 169 | 
            +
                        poly[1][1] -= R * r[1] * np.sin(theta)
         | 
| 170 | 
            +
                        ## p2, p3
         | 
| 171 | 
            +
                        theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
         | 
| 172 | 
            +
                        poly[3][0] += R * r[3] * np.cos(theta)
         | 
| 173 | 
            +
                        poly[3][1] += R * r[3] * np.sin(theta)
         | 
| 174 | 
            +
                        poly[2][0] -= R * r[2] * np.cos(theta)
         | 
| 175 | 
            +
                        poly[2][1] -= R * r[2] * np.sin(theta)
         | 
| 176 | 
            +
                        ## p0, p3
         | 
| 177 | 
            +
                        theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
         | 
| 178 | 
            +
                        poly[0][0] += R * r[0] * np.sin(theta)
         | 
| 179 | 
            +
                        poly[0][1] += R * r[0] * np.cos(theta)
         | 
| 180 | 
            +
                        poly[3][0] -= R * r[3] * np.sin(theta)
         | 
| 181 | 
            +
                        poly[3][1] -= R * r[3] * np.cos(theta)
         | 
| 182 | 
            +
                        ## p1, p2
         | 
| 183 | 
            +
                        theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
         | 
| 184 | 
            +
                        poly[1][0] += R * r[1] * np.sin(theta)
         | 
| 185 | 
            +
                        poly[1][1] += R * r[1] * np.cos(theta)
         | 
| 186 | 
            +
                        poly[2][0] -= R * r[2] * np.sin(theta)
         | 
| 187 | 
            +
                        poly[2][1] -= R * r[2] * np.cos(theta)
         | 
| 188 | 
            +
                    else:
         | 
| 189 | 
            +
                        ## p0, p3
         | 
| 190 | 
            +
                        # print poly
         | 
| 191 | 
            +
                        theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
         | 
| 192 | 
            +
                        poly[0][0] += R * r[0] * np.sin(theta)
         | 
| 193 | 
            +
                        poly[0][1] += R * r[0] * np.cos(theta)
         | 
| 194 | 
            +
                        poly[3][0] -= R * r[3] * np.sin(theta)
         | 
| 195 | 
            +
                        poly[3][1] -= R * r[3] * np.cos(theta)
         | 
| 196 | 
            +
                        ## p1, p2
         | 
| 197 | 
            +
                        theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
         | 
| 198 | 
            +
                        poly[1][0] += R * r[1] * np.sin(theta)
         | 
| 199 | 
            +
                        poly[1][1] += R * r[1] * np.cos(theta)
         | 
| 200 | 
            +
                        poly[2][0] -= R * r[2] * np.sin(theta)
         | 
| 201 | 
            +
                        poly[2][1] -= R * r[2] * np.cos(theta)
         | 
| 202 | 
            +
                        ## p0, p1
         | 
| 203 | 
            +
                        theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
         | 
| 204 | 
            +
                        poly[0][0] += R * r[0] * np.cos(theta)
         | 
| 205 | 
            +
                        poly[0][1] += R * r[0] * np.sin(theta)
         | 
| 206 | 
            +
                        poly[1][0] -= R * r[1] * np.cos(theta)
         | 
| 207 | 
            +
                        poly[1][1] -= R * r[1] * np.sin(theta)
         | 
| 208 | 
            +
                        ## p2, p3
         | 
| 209 | 
            +
                        theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
         | 
| 210 | 
            +
                        poly[3][0] += R * r[3] * np.cos(theta)
         | 
| 211 | 
            +
                        poly[3][1] += R * r[3] * np.sin(theta)
         | 
| 212 | 
            +
                        poly[2][0] -= R * r[2] * np.cos(theta)
         | 
| 213 | 
            +
                        poly[2][1] -= R * r[2] * np.sin(theta)
         | 
| 214 | 
            +
                    return poly
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def generate_quad(self, im_size, polys, tags):
         | 
| 217 | 
            +
                    """
         | 
| 218 | 
            +
                    Generate quadrangle.
         | 
| 219 | 
            +
                    """
         | 
| 220 | 
            +
                    h, w = im_size
         | 
| 221 | 
            +
                    poly_mask = np.zeros((h, w), dtype=np.uint8)
         | 
| 222 | 
            +
                    score_map = np.zeros((h, w), dtype=np.uint8)
         | 
| 223 | 
            +
                    # (x1, y1, ..., x4, y4, short_edge_norm)
         | 
| 224 | 
            +
                    geo_map = np.zeros((h, w, 9), dtype=np.float32)
         | 
| 225 | 
            +
                    # mask used during traning, to ignore some hard areas
         | 
| 226 | 
            +
                    training_mask = np.ones((h, w), dtype=np.uint8)
         | 
| 227 | 
            +
                    for poly_idx, poly_tag in enumerate(zip(polys, tags)):
         | 
| 228 | 
            +
                        poly = poly_tag[0]
         | 
| 229 | 
            +
                        tag = poly_tag[1]
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                        r = [None, None, None, None]
         | 
| 232 | 
            +
                        for i in range(4):
         | 
| 233 | 
            +
                            dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
         | 
| 234 | 
            +
                            dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
         | 
| 235 | 
            +
                            r[i] = min(dist1, dist2)
         | 
| 236 | 
            +
                        # score map
         | 
| 237 | 
            +
                        shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[
         | 
| 238 | 
            +
                            np.newaxis, :, :
         | 
| 239 | 
            +
                        ]
         | 
| 240 | 
            +
                        cv2.fillPoly(score_map, shrinked_poly, 1)
         | 
| 241 | 
            +
                        cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
         | 
| 242 | 
            +
                        # if the poly is too small, then ignore it during training
         | 
| 243 | 
            +
                        poly_h = min(
         | 
| 244 | 
            +
                            np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])
         | 
| 245 | 
            +
                        )
         | 
| 246 | 
            +
                        poly_w = min(
         | 
| 247 | 
            +
                            np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])
         | 
| 248 | 
            +
                        )
         | 
| 249 | 
            +
                        if min(poly_h, poly_w) < self.min_text_size:
         | 
| 250 | 
            +
                            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                        if tag:
         | 
| 253 | 
            +
                            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                        xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
         | 
| 256 | 
            +
                        # geo map.
         | 
| 257 | 
            +
                        y_in_poly = xy_in_poly[:, 0]
         | 
| 258 | 
            +
                        x_in_poly = xy_in_poly[:, 1]
         | 
| 259 | 
            +
                        poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
         | 
| 260 | 
            +
                        poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
         | 
| 261 | 
            +
                        for pno in range(4):
         | 
| 262 | 
            +
                            geo_channel_beg = pno * 2
         | 
| 263 | 
            +
                            geo_map[y_in_poly, x_in_poly, geo_channel_beg] = (
         | 
| 264 | 
            +
                                x_in_poly - poly[pno, 0]
         | 
| 265 | 
            +
                            )
         | 
| 266 | 
            +
                            geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = (
         | 
| 267 | 
            +
                                y_in_poly - poly[pno, 1]
         | 
| 268 | 
            +
                            )
         | 
| 269 | 
            +
                        geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0)
         | 
| 270 | 
            +
                    return score_map, geo_map, training_mask
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
         | 
| 273 | 
            +
                    """
         | 
| 274 | 
            +
                    make random crop from the input image
         | 
| 275 | 
            +
                    :param im:
         | 
| 276 | 
            +
                    :param polys:
         | 
| 277 | 
            +
                    :param tags:
         | 
| 278 | 
            +
                    :param crop_background:
         | 
| 279 | 
            +
                    :param max_tries:
         | 
| 280 | 
            +
                    :return:
         | 
| 281 | 
            +
                    """
         | 
| 282 | 
            +
                    h, w, _ = im.shape
         | 
| 283 | 
            +
                    pad_h = h // 10
         | 
| 284 | 
            +
                    pad_w = w // 10
         | 
| 285 | 
            +
                    h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
         | 
| 286 | 
            +
                    w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
         | 
| 287 | 
            +
                    for poly in polys:
         | 
| 288 | 
            +
                        poly = np.round(poly, decimals=0).astype(np.int32)
         | 
| 289 | 
            +
                        minx = np.min(poly[:, 0])
         | 
| 290 | 
            +
                        maxx = np.max(poly[:, 0])
         | 
| 291 | 
            +
            w_array[minx + pad_w : maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h : maxy + pad_h] = 1
        # ensure the cropped area not across a text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return im, polys, tags

        for i in range(max_tries):
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if (
                xmax - xmin < self.min_crop_side_ratio * w
                or ymax - ymin < self.min_crop_side_ratio * h
            ):
                # area too small
                continue
            if polys.shape[0] != 0:
                poly_axis_in_area = (
                    (polys[:, :, 0] >= xmin)
                    & (polys[:, :, 0] <= xmax)
                    & (polys[:, :, 1] >= ymin)
                    & (polys[:, :, 1] <= ymax)
                )
                selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
            else:
                selected_polys = []

            if len(selected_polys) == 0:
                # no text in this area
                if crop_background:
                    im = im[ymin : ymax + 1, xmin : xmax + 1, :]
                    polys = []
                    tags = []
                    return im, polys, tags
                else:
                    continue

            im = im[ymin : ymax + 1, xmin : xmax + 1, :]
            polys = polys[selected_polys]
            tags = tags[selected_polys]
            polys[:, :, 0] -= xmin
            polys[:, :, 1] -= ymin
            return im, polys, tags
        return im, polys, tags
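    # The arrays built above implement the crop sampling used for EAST-style
    # training: h_array / w_array mark every (padded) row / column covered by
    # some text polygon, so h_axis / w_axis list positions guaranteed to be
    # text-free. Because crop boundaries are drawn only from those positions,
    # a random crop can never cut through the middle of a text instance; when
    # no text-free rows or columns exist, the image is returned unchanged.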
    def crop_background_infor(self, im, text_polys, text_tags):
        im, text_polys, text_tags = self.crop_area(
            im, text_polys, text_tags, crop_background=True
        )

        if len(text_polys) > 0:
            return None
        # pad and resize image
        input_size = self.input_size
        im, ratio = self.preprocess(im)
        score_map = np.zeros((input_size, input_size), dtype=np.float32)
        geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
        training_mask = np.ones((input_size, input_size), dtype=np.float32)
        return im, score_map, geo_map, training_mask

    def crop_foreground_infor(self, im, text_polys, text_tags):
        im, text_polys, text_tags = self.crop_area(
            im, text_polys, text_tags, crop_background=False
        )

        if text_polys.shape[0] == 0:
            return None
        # continue for all ignore case
        if np.sum((text_tags * 1.0)) >= text_tags.size:
            return None
        # pad and resize image
        input_size = self.input_size
        im, ratio = self.preprocess(im)
        text_polys[:, :, 0] *= ratio
        text_polys[:, :, 1] *= ratio
        _, _, new_h, new_w = im.shape
        #         print(im.shape)
        #         self.draw_img_polys(im, text_polys)
        score_map, geo_map, training_mask = self.generate_quad(
            (new_h, new_w), text_polys, text_tags
        )
        return im, score_map, geo_map, training_mask

    def __call__(self, data):
        im = data["image"]
        text_polys = data["polys"]
        text_tags = data["ignore_tags"]
        if im is None:
            return None
        if text_polys.shape[0] == 0:
            return None

        # add rotate cases
        if np.random.rand() < 0.5:
            im, text_polys = self.rotate_im_poly(im, text_polys)
        h, w, _ = im.shape
        text_polys, text_tags = self.check_and_validate_polys(
            text_polys, text_tags, h, w
        )
        if text_polys.shape[0] == 0:
            return None

        # random scale this image
        rd_scale = np.random.choice(self.random_scale)
        im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
        text_polys *= rd_scale
        if np.random.rand() < self.background_ratio:
            outs = self.crop_background_infor(im, text_polys, text_tags)
        else:
            outs = self.crop_foreground_infor(im, text_polys, text_tags)

        if outs is None:
            return None
        im, score_map, geo_map, training_mask = outs
        score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
        geo_map = np.swapaxes(geo_map, 1, 2)
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = geo_map[:, ::4, ::4].astype(np.float32)
        training_mask = training_mask[np.newaxis, ::4, ::4]
        training_mask = training_mask.astype(np.float32)

        data["image"] = im[0]
        data["score_map"] = score_map
        data["geo_map"] = geo_map
        data["training_mask"] = training_mask
        return data
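For reference, the transform above is a plain callable over a dict, so its input/output contract can be checked with synthetic data. The snippet below is only an illustrative sketch built from the keys read and written by the __call__ method shown above; it does not instantiate the class (whose constructor is defined earlier in this file), and the array shapes are made-up test values.

import numpy as np

# Keys consumed by the EAST-style training transform's __call__ shown above.
sample = {
    "image": np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8),
    "polys": np.array([[[10, 10], [200, 10], [200, 60], [10, 60]]], dtype=np.float32),
    "ignore_tags": np.array([False]),
}
# A successful call returns the same dict with "image" replaced by the
# preprocessed tensor and three new keys added: "score_map", "geo_map"
# (9 channels) and "training_mask", each spatially downsampled by 4
# (the ::4 slicing above); it returns None when the random crop finds
# no usable text region.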
    	
        ocr/ppocr/data/imaug/fce_aug.py
    ADDED
    
    @@ -0,0 +1,563 @@
import math
import os

import cv2
import numpy as np
from PIL import Image, ImageDraw
from shapely.geometry import Polygon

from postprocess.poly_nms import poly_intersection


class RandomScaling:
    def __init__(self, size=800, scale=(3.0 / 4, 5.0 / 2), **kwargs):
        """Randomly scale the image while keeping its aspect ratio.

        Args:
            size (int) : Base size before scaling.
            scale (tuple(float)) : The range of scaling.
        """
        assert isinstance(size, int)
        assert isinstance(scale, float) or isinstance(scale, tuple)
        self.size = size
        self.scale = scale if isinstance(scale, tuple) else (1 - scale, 1 + scale)

    def __call__(self, data):
        image = data["image"]
        text_polys = data["polys"]
        h, w, _ = image.shape

        aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
        scales = self.size * 1.0 / max(h, w) * aspect_ratio
        scales = np.array([scales, scales])
        out_size = (int(h * scales[1]), int(w * scales[0]))
        image = cv2.resize(image, out_size[::-1])

        data["image"] = image
        text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
        text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
        data["polys"] = text_polys

        return data

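# RandomScaling above rescales both image axes by the same factor,
# size / max(h, w) * u with u drawn uniformly from `scale`, and multiplies
# the polygon coordinates by the same per-axis factors so the annotations
# stay aligned with the resized image.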
class RandomCropFlip:
    def __init__(
        self, pad_ratio=0.1, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2, **kwargs
    ):
        """Randomly crop and flip a patch of the image.

        Args:
            crop_ratio (float): The ratio of cropping.
            iter_num (int): Number of operations.
            min_area_ratio (float): Minimal area ratio between cropped patch
                and original image.
        """
        assert isinstance(crop_ratio, float)
        assert isinstance(iter_num, int)
        assert isinstance(min_area_ratio, float)

        self.pad_ratio = pad_ratio
        self.epsilon = 1e-2
        self.crop_ratio = crop_ratio
        self.iter_num = iter_num
        self.min_area_ratio = min_area_ratio

    def __call__(self, results):
        for i in range(self.iter_num):
            results = self.random_crop_flip(results)

        return results

    def random_crop_flip(self, results):
        image = results["image"]
        polygons = results["polys"]
        ignore_tags = results["ignore_tags"]
        if len(polygons) == 0:
            return results

        if np.random.random() >= self.crop_ratio:
            return results

        h, w, _ = image.shape
        area = h * w
        pad_h = int(h * self.pad_ratio)
        pad_w = int(w * self.pad_ratio)
        h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, pad_w)
        if len(h_axis) == 0 or len(w_axis) == 0:
            return results

        attempt = 0
        while attempt < 50:
            attempt += 1
            polys_keep = []
            polys_new = []
            ignore_tags_keep = []
            ignore_tags_new = []
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
                # area too small
                continue

            pts = np.stack(
                [[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]
            ).T.astype(np.int32)
            pp = Polygon(pts)
            fail_flag = False
            for polygon, ignore_tag in zip(polygons, ignore_tags):
                ppi = Polygon(polygon.reshape(-1, 2))
                ppiou, _ = poly_intersection(ppi, pp, buffer=0)
                if (
                    np.abs(ppiou - float(ppi.area)) > self.epsilon
                    and np.abs(ppiou) > self.epsilon
                ):
                    fail_flag = True
                    break
                elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
                    polys_new.append(polygon)
                    ignore_tags_new.append(ignore_tag)
                else:
                    polys_keep.append(polygon)
                    ignore_tags_keep.append(ignore_tag)

            if fail_flag:
                continue
            else:
                break

        cropped = image[ymin:ymax, xmin:xmax, :]
        select_type = np.random.randint(3)
        if select_type == 0:
            img = np.ascontiguousarray(cropped[:, ::-1])
        elif select_type == 1:
            img = np.ascontiguousarray(cropped[::-1, :])
        else:
            img = np.ascontiguousarray(cropped[::-1, ::-1])
        image[ymin:ymax, xmin:xmax, :] = img
        results["img"] = image

        if len(polys_new) != 0:
            height, width, _ = cropped.shape
            if select_type == 0:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    polys_new[idx] = poly
            elif select_type == 1:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            else:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            polygons = polys_keep + polys_new
            ignore_tags = ignore_tags_keep + ignore_tags_new
            results["polys"] = np.array(polygons)
            results["ignore_tags"] = ignore_tags

        return results

    def generate_crop_target(self, image, all_polys, pad_h, pad_w):
        """Generate crop target and make sure not to crop the polygon
        instances.

        Args:
            image (ndarray): The image to be cropped.
            all_polys (list[list[ndarray]]): All polygons including ground
                truth polygons and ground truth ignored polygons.
            pad_h (int): Padding length of height.
            pad_w (int): Padding length of width.
        Returns:
            h_axis (ndarray): Vertical cropping range.
            w_axis (ndarray): Horizontal cropping range.
        """
        h, w, _ = image.shape
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)

        text_polys = []
        for polygon in all_polys:
            rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
            box = cv2.boxPoints(rect)
            box = np.int0(box)
            text_polys.append([box[0], box[1], box[2], box[3]])

        polys = np.array(text_polys, dtype=np.int32)
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w : maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h : maxy + pad_h] = 1

        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        return h_axis, w_axis

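# RandomCropFlip above samples a patch whose borders lie on text-free rows and
# columns (see generate_crop_target), flips that patch horizontally,
# vertically or both, pastes it back in place, and mirrors the coordinates of
# every polygon that lies entirely inside the patch. If a sampled patch only
# partially overlaps some polygon, the attempt is discarded and re-sampled
# (up to 50 times).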
            +
            class RandomCropPolyInstances:
         | 
| 214 | 
            +
                """Randomly crop images and make sure to contain at least one intact
         | 
| 215 | 
            +
                instance."""
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
         | 
| 218 | 
            +
                    super().__init__()
         | 
| 219 | 
            +
                    self.crop_ratio = crop_ratio
         | 
| 220 | 
            +
                    self.min_side_ratio = min_side_ratio
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    assert isinstance(min_len, int)
         | 
| 225 | 
            +
                    assert len(valid_array) > min_len
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    start_array = valid_array.copy()
         | 
| 228 | 
            +
                    max_start = min(len(start_array) - min_len, max_start)
         | 
| 229 | 
            +
                    start_array[max_start:] = 0
         | 
| 230 | 
            +
                    start_array[0] = 1
         | 
| 231 | 
            +
                    diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
         | 
| 232 | 
            +
                    region_starts = np.where(diff_array < 0)[0]
         | 
| 233 | 
            +
                    region_ends = np.where(diff_array > 0)[0]
         | 
| 234 | 
            +
                    region_ind = np.random.randint(0, len(region_starts))
         | 
| 235 | 
            +
                    start = np.random.randint(region_starts[region_ind], region_ends[region_ind])
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    end_array = valid_array.copy()
         | 
| 238 | 
            +
                    min_end = max(start + min_len, min_end)
         | 
| 239 | 
            +
                    end_array[:min_end] = 0
         | 
| 240 | 
            +
                    end_array[-1] = 1
         | 
| 241 | 
            +
                    diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
         | 
| 242 | 
            +
                    region_starts = np.where(diff_array < 0)[0]
         | 
| 243 | 
            +
                    region_ends = np.where(diff_array > 0)[0]
         | 
| 244 | 
            +
                    region_ind = np.random.randint(0, len(region_starts))
         | 
| 245 | 
            +
                    end = np.random.randint(region_starts[region_ind], region_ends[region_ind])
         | 
| 246 | 
            +
                    return start, end
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                def sample_crop_box(self, img_size, results):
         | 
| 249 | 
            +
                    """Generate crop box and make sure not to crop the polygon instances.
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    Args:
         | 
| 252 | 
            +
                        img_size (tuple(int)): The image size (h, w).
         | 
| 253 | 
            +
                        results (dict): The results dict.
         | 
| 254 | 
            +
                    """
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    assert isinstance(img_size, tuple)
         | 
| 257 | 
            +
                    h, w = img_size[:2]
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    key_masks = results["polys"]
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    x_valid_array = np.ones(w, dtype=np.int32)
         | 
| 262 | 
            +
                    y_valid_array = np.ones(h, dtype=np.int32)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    selected_mask = key_masks[np.random.randint(0, len(key_masks))]
         | 
| 265 | 
            +
                    selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
         | 
| 266 | 
            +
                    max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
         | 
| 267 | 
            +
                    min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
         | 
| 268 | 
            +
                    max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
         | 
| 269 | 
            +
                    min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    for mask in key_masks:
         | 
| 272 | 
            +
                        mask = mask.reshape((-1, 2)).astype(np.int32)
         | 
| 273 | 
            +
                        clip_x = np.clip(mask[:, 0], 0, w - 1)
         | 
| 274 | 
            +
                        clip_y = np.clip(mask[:, 1], 0, h - 1)
         | 
| 275 | 
            +
                        min_x, max_x = np.min(clip_x), np.max(clip_x)
         | 
| 276 | 
            +
                        min_y, max_y = np.min(clip_y), np.max(clip_y)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                        x_valid_array[min_x - 2 : max_x + 3] = 0
         | 
| 279 | 
            +
                        y_valid_array[min_y - 2 : max_y + 3] = 0
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                    min_w = int(w * self.min_side_ratio)
         | 
| 282 | 
            +
                    min_h = int(h * self.min_side_ratio)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    x1, x2 = self.sample_valid_start_end(
         | 
| 285 | 
            +
                        x_valid_array, min_w, max_x_start, min_x_end
         | 
| 286 | 
            +
                    )
         | 
| 287 | 
            +
                    y1, y2 = self.sample_valid_start_end(
         | 
| 288 | 
            +
                        y_valid_array, min_h, max_y_start, min_y_end
         | 
| 289 | 
            +
                    )
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    return np.array([x1, y1, x2, y2])
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                def crop_img(self, img, bbox):
         | 
| 294 | 
            +
                    assert img.ndim == 3
         | 
| 295 | 
            +
                    h, w, _ = img.shape
         | 
| 296 | 
            +
                    assert 0 <= bbox[1] < bbox[3] <= h
         | 
| 297 | 
            +
                    assert 0 <= bbox[0] < bbox[2] <= w
         | 
| 298 | 
            +
                    return img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                def __call__(self, results):
         | 
| 301 | 
            +
                    image = results["image"]
         | 
| 302 | 
            +
                    polygons = results["polys"]
         | 
| 303 | 
            +
                    ignore_tags = results["ignore_tags"]
         | 
| 304 | 
            +
                    if len(polygons) < 1:
         | 
| 305 | 
            +
                        return results
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    if np.random.random_sample() < self.crop_ratio:
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                        crop_box = self.sample_crop_box(image.shape, results)
         | 
| 310 | 
            +
                        img = self.crop_img(image, crop_box)
         | 
| 311 | 
            +
                        results["image"] = img
         | 
| 312 | 
            +
                        # crop and filter masks
         | 
| 313 | 
            +
                        x1, y1, x2, y2 = crop_box
         | 
| 314 | 
            +
                        w = max(x2 - x1, 1)
         | 
| 315 | 
            +
                        h = max(y2 - y1, 1)
         | 
| 316 | 
            +
                        polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
         | 
| 317 | 
            +
                        polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                        valid_masks_list = []
         | 
| 320 | 
            +
                        valid_tags_list = []
         | 
| 321 | 
            +
                        for ind, polygon in enumerate(polygons):
         | 
| 322 | 
            +
                            if (
         | 
| 323 | 
            +
                                (polygon[:, ::2] > -4).all()
         | 
| 324 | 
            +
                                and (polygon[:, ::2] < w + 4).all()
         | 
| 325 | 
            +
                                and (polygon[:, 1::2] > -4).all()
         | 
| 326 | 
            +
                                and (polygon[:, 1::2] < h + 4).all()
         | 
| 327 | 
            +
                            ):
         | 
| 328 | 
            +
                                polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
         | 
| 329 | 
            +
                                polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
         | 
| 330 | 
            +
                                valid_masks_list.append(polygon)
         | 
| 331 | 
            +
                                valid_tags_list.append(ignore_tags[ind])
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                        results["polys"] = np.array(valid_masks_list)
         | 
| 334 | 
            +
                        results["ignore_tags"] = valid_tags_list
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    return results
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                def __repr__(self):
         | 
| 339 | 
            +
                    repr_str = self.__class__.__name__
         | 
| 340 | 
            +
                    return repr_str
         | 
| 341 | 
            +
             | 
| 342 | 
            +
             | 
| 343 | 
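# RandomCropPolyInstances above keeps at least one instance intact: a random
# polygon is picked and sample_crop_box constrains the crop so that this
# polygon stays fully inside it, while the crop edges themselves are sampled
# from rows / columns not occupied by any polygon. Polygons that end up
# outside the crop (beyond a 4-pixel tolerance) are dropped together with
# their ignore tags.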
            +
            class RandomRotatePolyInstances:
         | 
| 344 | 
            +
                def __init__(
         | 
| 345 | 
            +
                    self,
         | 
| 346 | 
            +
                    rotate_ratio=0.5,
         | 
| 347 | 
            +
                    max_angle=10,
         | 
| 348 | 
            +
                    pad_with_fixed_color=False,
         | 
| 349 | 
            +
                    pad_value=(0, 0, 0),
         | 
| 350 | 
            +
                    **kwargs
         | 
| 351 | 
            +
                ):
         | 
| 352 | 
            +
                    """Randomly rotate images and polygon masks.
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    Args:
         | 
| 355 | 
            +
                        rotate_ratio (float): The ratio of samples to operate rotation.
         | 
| 356 | 
            +
                        max_angle (int): The maximum rotation angle.
         | 
| 357 | 
            +
                        pad_with_fixed_color (bool): The flag for whether to pad rotated
         | 
| 358 | 
            +
                           image with fixed value. If set to False, the rotated image will
         | 
| 359 | 
            +
                           be padded onto cropped image.
         | 
| 360 | 
            +
                        pad_value (tuple(int)): The color value for padding rotated image.
         | 
| 361 | 
            +
                    """
         | 
| 362 | 
            +
                    self.rotate_ratio = rotate_ratio
         | 
| 363 | 
            +
                    self.max_angle = max_angle
         | 
| 364 | 
            +
                    self.pad_with_fixed_color = pad_with_fixed_color
         | 
| 365 | 
            +
                    self.pad_value = pad_value
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                def rotate(self, center, points, theta, center_shift=(0, 0)):
         | 
| 368 | 
            +
                    # rotate points.
         | 
| 369 | 
            +
                    (center_x, center_y) = center
         | 
| 370 | 
            +
                    center_y = -center_y
         | 
| 371 | 
            +
                    x, y = points[:, ::2], points[:, 1::2]
         | 
| 372 | 
            +
                    y = -y
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    theta = theta / 180 * math.pi
         | 
| 375 | 
            +
                    cos = math.cos(theta)
         | 
| 376 | 
            +
                    sin = math.sin(theta)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    x = x - center_x
         | 
| 379 | 
            +
                    y = y - center_y
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    _x = center_x + x * cos - y * sin + center_shift[0]
         | 
| 382 | 
            +
                    _y = -(center_y + x * sin + y * cos) + center_shift[1]
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    points[:, ::2], points[:, 1::2] = _x, _y
         | 
| 385 | 
            +
                    return points
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                def cal_canvas_size(self, ori_size, degree):
         | 
| 388 | 
            +
                    assert isinstance(ori_size, tuple)
         | 
| 389 | 
            +
                    angle = degree * math.pi / 180.0
         | 
| 390 | 
            +
                    h, w = ori_size[:2]
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    cos = math.cos(angle)
         | 
| 393 | 
            +
                    sin = math.sin(angle)
         | 
| 394 | 
            +
                    canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
         | 
| 395 | 
            +
                    canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    canvas_size = (canvas_h, canvas_w)
         | 
| 398 | 
            +
                    return canvas_size
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                def sample_angle(self, max_angle):
         | 
| 401 | 
            +
                    angle = np.random.random_sample() * 2 * max_angle - max_angle
         | 
| 402 | 
            +
                    return angle
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                def rotate_img(self, img, angle, canvas_size):
         | 
| 405 | 
            +
                    h, w = img.shape[:2]
         | 
| 406 | 
            +
                    rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
         | 
| 407 | 
            +
                    rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
         | 
| 408 | 
            +
                    rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    if self.pad_with_fixed_color:
         | 
| 411 | 
            +
                        target_img = cv2.warpAffine(
         | 
| 412 | 
            +
                            img,
         | 
| 413 | 
            +
                            rotation_matrix,
         | 
| 414 | 
            +
                            (canvas_size[1], canvas_size[0]),
         | 
| 415 | 
            +
                            flags=cv2.INTER_NEAREST,
         | 
| 416 | 
            +
                            borderValue=self.pad_value,
         | 
| 417 | 
            +
                        )
         | 
| 418 | 
            +
                    else:
         | 
| 419 | 
            +
                        mask = np.zeros_like(img)
         | 
| 420 | 
            +
                        (h_ind, w_ind) = (
         | 
| 421 | 
            +
                            np.random.randint(0, h * 7 // 8),
         | 
| 422 | 
            +
                            np.random.randint(0, w * 7 // 8),
         | 
| 423 | 
            +
                        )
         | 
| 424 | 
            +
                        img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
         | 
| 425 | 
            +
                        img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                        mask = cv2.warpAffine(
         | 
| 428 | 
            +
                            mask,
         | 
| 429 | 
            +
                            rotation_matrix,
         | 
| 430 | 
            +
                            (canvas_size[1], canvas_size[0]),
         | 
| 431 | 
            +
                            borderValue=[1, 1, 1],
         | 
| 432 | 
            +
                        )
         | 
| 433 | 
            +
                        target_img = cv2.warpAffine(
         | 
| 434 | 
            +
                            img,
         | 
| 435 | 
            +
                            rotation_matrix,
         | 
| 436 | 
            +
                            (canvas_size[1], canvas_size[0]),
         | 
| 437 | 
            +
                            borderValue=[0, 0, 0],
         | 
| 438 | 
            +
                        )
         | 
| 439 | 
            +
                        target_img = target_img + img_cut * mask
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    return target_img
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                def __call__(self, results):
         | 
| 444 | 
            +
                    if np.random.random_sample() < self.rotate_ratio:
         | 
| 445 | 
            +
                        image = results["image"]
         | 
| 446 | 
            +
                        polygons = results["polys"]
         | 
| 447 | 
            +
                        h, w = image.shape[:2]
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                        angle = self.sample_angle(self.max_angle)
         | 
| 450 | 
            +
                        canvas_size = self.cal_canvas_size((h, w), angle)
         | 
| 451 | 
            +
                        center_shift = (
         | 
| 452 | 
            +
                            int((canvas_size[1] - w) / 2),
         | 
| 453 | 
            +
                            int((canvas_size[0] - h) / 2),
         | 
| 454 | 
            +
                        )
         | 
| 455 | 
            +
                        image = self.rotate_img(image, angle, canvas_size)
         | 
| 456 | 
            +
                        results["image"] = image
         | 
| 457 | 
            +
                        # rotate polygons
         | 
| 458 | 
            +
                        rotated_masks = []
         | 
| 459 | 
            +
                        for mask in polygons:
         | 
| 460 | 
            +
                            rotated_mask = self.rotate((w / 2, h / 2), mask, angle, center_shift)
         | 
| 461 | 
            +
                            rotated_masks.append(rotated_mask)
         | 
| 462 | 
            +
                        results["polys"] = np.array(rotated_masks)
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    return results
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                def __repr__(self):
         | 
| 467 | 
            +
                    repr_str = self.__class__.__name__
         | 
| 468 | 
            +
                    return repr_str
         | 
| 469 | 
            +
             | 
| 470 | 
            +
             | 
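# RandomRotatePolyInstances above draws an angle in [-max_angle, max_angle],
# enlarges the canvas (cal_canvas_size) so the rotated image is not clipped,
# and applies the standard 2-D rotation about the image center to the polygon
# points (the y axis is negated in rotate() because image coordinates grow
# downwards), shifting them by the canvas offset. The exposed border is filled
# either with pad_value or with a resized random patch of the original image.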
| 471 | 
            +
            class SquareResizePad:
         | 
| 472 | 
            +
                def __init__(
         | 
| 473 | 
            +
                    self,
         | 
| 474 | 
            +
                    target_size,
         | 
| 475 | 
            +
                    pad_ratio=0.6,
         | 
| 476 | 
            +
                    pad_with_fixed_color=False,
         | 
| 477 | 
            +
                    pad_value=(0, 0, 0),
         | 
| 478 | 
            +
                    **kwargs
         | 
| 479 | 
            +
                ):
         | 
| 480 | 
            +
                    """Resize or pad images to be square shape.
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    Args:
         | 
| 483 | 
            +
                        target_size (int): The target size of square shaped image.
         | 
| 484 | 
            +
                        pad_with_fixed_color (bool): The flag for whether to pad rotated
         | 
| 485 | 
            +
                           image with fixed value. If set to False, the rescales image will
         | 
| 486 | 
            +
                           be padded onto cropped image.
         | 
| 487 | 
            +
                        pad_value (tuple(int)): The color value for padding rotated image.
         | 
| 488 | 
            +
                    """
         | 
| 489 | 
            +
                    assert isinstance(target_size, int)
         | 
| 490 | 
            +
                    assert isinstance(pad_ratio, float)
         | 
| 491 | 
            +
                    assert isinstance(pad_with_fixed_color, bool)
         | 
| 492 | 
            +
                    assert isinstance(pad_value, tuple)
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    self.target_size = target_size
         | 
| 495 | 
            +
                    self.pad_ratio = pad_ratio
         | 
| 496 | 
            +
                    self.pad_with_fixed_color = pad_with_fixed_color
         | 
| 497 | 
            +
                    self.pad_value = pad_value
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                def resize_img(self, img, keep_ratio=True):
         | 
| 500 | 
            +
                    h, w, _ = img.shape
         | 
| 501 | 
            +
                    if keep_ratio:
         | 
| 502 | 
            +
                        t_h = self.target_size if h >= w else int(h * self.target_size / w)
         | 
| 503 | 
            +
                        t_w = self.target_size if h <= w else int(w * self.target_size / h)
         | 
| 504 | 
            +
                    else:
         | 
| 505 | 
            +
                        t_h = t_w = self.target_size
         | 
| 506 | 
            +
                    img = cv2.resize(img, (t_w, t_h))
         | 
| 507 | 
            +
                    return img, (t_h, t_w)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                def square_pad(self, img):
         | 
| 510 | 
            +
                    h, w = img.shape[:2]
         | 
| 511 | 
            +
                    if h == w:
         | 
| 512 | 
            +
            return img, (0, 0)
        pad_size = max(h, w)
        if self.pad_with_fixed_color:
            expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
            expand_img[:] = self.pad_value
        else:
            (h_ind, w_ind) = (
                np.random.randint(0, h * 7 // 8),
                np.random.randint(0, w * 7 // 8),
            )
            img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
            expand_img = cv2.resize(img_cut, (pad_size, pad_size))
        if h > w:
            y0, x0 = 0, (h - w) // 2
        else:
            y0, x0 = (w - h) // 2, 0
        expand_img[y0 : y0 + h, x0 : x0 + w] = img
        offset = (x0, y0)

        return expand_img, offset

    def square_pad_mask(self, points, offset):
        x0, y0 = offset
        pad_points = points.copy()
        pad_points[::2] = pad_points[::2] + x0
        pad_points[1::2] = pad_points[1::2] + y0
        return pad_points

    def __call__(self, results):
        image = results["image"]
        polygons = results["polys"]
        h, w = image.shape[:2]

        if np.random.random_sample() < self.pad_ratio:
            image, out_size = self.resize_img(image, keep_ratio=True)
            image, offset = self.square_pad(image)
        else:
            image, out_size = self.resize_img(image, keep_ratio=False)
            offset = (0, 0)
        results["image"] = image
        try:
            polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[0]
            polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[1]
        except:
            pass
        results["polys"] = polygons

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
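The `__call__` above rescales every polygon from the original (h, w) frame into the resized-and-padded frame. Below is a minimal numpy-only sketch of that transform with illustrative values (a 100x200 image resized keep-ratio to 160x320 and then square-padded, giving offset (0, 80)); the (num_boxes, 1, 2*num_points) layout is the one the slicing above assumes:

    import numpy as np

    # one box with 4 interleaved (x, y) points, in a 100 x 200 (h x w) source image
    polys = np.array([[[10.0, 20.0, 60.0, 20.0, 60.0, 40.0, 10.0, 40.0]]])
    h, w = 100, 200
    out_size = (160, 320)   # (new_h, new_w) after a keep-ratio resize
    offset = (0, 80)        # (x0, y0) produced by square_pad

    scaled = polys.copy()
    scaled[:, :, 0::2] = scaled[:, :, 0::2] * out_size[1] / w + offset[0]  # x coords
    scaled[:, :, 1::2] = scaled[:, :, 1::2] * out_size[0] / h + offset[1]  # y coords
    # x is scaled by 320/200, y by 160/100, then y is shifted down by the pad offset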
    	
        ocr/ppocr/data/imaug/fce_targets.py
    ADDED
    
@@ -0,0 +1,671 @@
import cv2
import numpy as np
from numpy.fft import fft
from numpy.linalg import norm


def vector_slope(vec):
    assert len(vec) == 2
    return abs(vec[1] / (vec[0] + 1e-8))


class FCENetTargets:
    """Generate the ground truth targets of FCENet: Fourier Contour Embedding
    for Arbitrary-Shaped Text Detection.

    [https://arxiv.org/abs/2104.10442]

    Args:
        fourier_degree (int): The maximum Fourier transform degree k.
        resample_step (float): The step size for resampling the text center
            line (TCL). It's better not to exceed half of the minimum width.
        center_region_shrink_ratio (float): The shrink ratio of text center
            region.
        level_size_divisors (tuple(int)): The downsample ratio on each level.
        level_proportion_range (tuple(tuple(int))): The range of text sizes
            assigned to each level.
    """

    def __init__(
        self,
        fourier_degree=5,
        resample_step=4.0,
        center_region_shrink_ratio=0.3,
        level_size_divisors=(8, 16, 32),
        level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
        orientation_thr=2.0,
        **kwargs
    ):

        super().__init__()
        assert isinstance(level_size_divisors, tuple)
        assert isinstance(level_proportion_range, tuple)
        assert len(level_size_divisors) == len(level_proportion_range)
        self.fourier_degree = fourier_degree
        self.resample_step = resample_step
        self.center_region_shrink_ratio = center_region_shrink_ratio
        self.level_size_divisors = level_size_divisors
        self.level_proportion_range = level_proportion_range

        self.orientation_thr = orientation_thr

    def vector_angle(self, vec1, vec2):
        if vec1.ndim > 1:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
        if vec2.ndim > 1:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))

    def resample_line(self, line, n):
        """Resample n points on a line.

        Args:
            line (ndarray): The points composing a line.
            n (int): The resampled points number.

        Returns:
            resampled_line (ndarray): The points composing the resampled line.
        """

        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
        total_length = sum(length_list)
        length_cumsum = np.cumsum([0.0] + length_list)
        delta_length = total_length / (float(n) + 1e-8)

        current_edge_ind = 0
        resampled_line = [line[0]]

        for i in range(1, n):
            current_line_len = i * delta_length

            while current_line_len >= length_cumsum[current_edge_ind + 1]:
                current_edge_ind += 1
            current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
            end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
            current_point = (
                line[current_edge_ind]
                + (line[current_edge_ind + 1] - line[current_edge_ind])
                * end_shift_ratio
            )
            resampled_line.append(current_point)

        resampled_line.append(line[-1])
        resampled_line = np.array(resampled_line)

        return resampled_line

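    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # A minimal sketch of what resample_line returns, assuming an instance
    # `targets = FCENetTargets()`:
    #
    #   line = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])
    #   pts = targets.resample_line(line, 4)
    #   # pts.shape == (5, 2): 4 points spaced ~5 apart along the polyline,
    #   # plus the original end point appended at the end.
    # --------------------------------------------------------------------------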
            +
                def reorder_poly_edge(self, points):
         | 
| 108 | 
            +
                    """Get the respective points composing head edge, tail edge, top
         | 
| 109 | 
            +
                    sideline and bottom sideline.
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    Args:
         | 
| 112 | 
            +
                        points (ndarray): The points composing a text polygon.
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    Returns:
         | 
| 115 | 
            +
                        head_edge (ndarray): The two points composing the head edge of text
         | 
| 116 | 
            +
                            polygon.
         | 
| 117 | 
            +
                        tail_edge (ndarray): The two points composing the tail edge of text
         | 
| 118 | 
            +
                            polygon.
         | 
| 119 | 
            +
                        top_sideline (ndarray): The points composing top curved sideline of
         | 
| 120 | 
            +
                            text polygon.
         | 
| 121 | 
            +
                        bot_sideline (ndarray): The points composing bottom curved sideline
         | 
| 122 | 
            +
                            of text polygon.
         | 
| 123 | 
            +
                    """
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    assert points.ndim == 2
         | 
| 126 | 
            +
                    assert points.shape[0] >= 4
         | 
| 127 | 
            +
                    assert points.shape[1] == 2
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
         | 
| 130 | 
            +
                    head_edge, tail_edge = points[head_inds], points[tail_inds]
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    pad_points = np.vstack([points, points])
         | 
| 133 | 
            +
                    if tail_inds[1] < 1:
         | 
| 134 | 
            +
                        tail_inds[1] = len(points)
         | 
| 135 | 
            +
                    sideline1 = pad_points[head_inds[1] : tail_inds[1]]
         | 
| 136 | 
            +
                    sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
         | 
| 137 | 
            +
                    sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    if sideline_mean_shift[1] > 0:
         | 
| 140 | 
            +
                        top_sideline, bot_sideline = sideline2, sideline1
         | 
| 141 | 
            +
                    else:
         | 
| 142 | 
            +
                        top_sideline, bot_sideline = sideline1, sideline2
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    return head_edge, tail_edge, top_sideline, bot_sideline
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def find_head_tail(self, points, orientation_thr):
         | 
| 147 | 
            +
                    """Find the head edge and tail edge of a text polygon.
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    Args:
         | 
| 150 | 
            +
                        points (ndarray): The points composing a text polygon.
         | 
| 151 | 
            +
                        orientation_thr (float): The threshold for distinguishing between
         | 
| 152 | 
            +
                            head edge and tail edge among the horizontal and vertical edges
         | 
| 153 | 
            +
                            of a quadrangle.
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    Returns:
         | 
| 156 | 
            +
                        head_inds (list): The indexes of two points composing head edge.
         | 
| 157 | 
            +
                        tail_inds (list): The indexes of two points composing tail edge.
         | 
| 158 | 
            +
                    """
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    assert points.ndim == 2
         | 
| 161 | 
            +
                    assert points.shape[0] >= 4
         | 
| 162 | 
            +
                    assert points.shape[1] == 2
         | 
| 163 | 
            +
                    assert isinstance(orientation_thr, float)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    if len(points) > 4:
         | 
| 166 | 
            +
                        pad_points = np.vstack([points, points[0]])
         | 
| 167 | 
            +
                        edge_vec = pad_points[1:] - pad_points[:-1]
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                        theta_sum = []
         | 
| 170 | 
            +
                        adjacent_vec_theta = []
         | 
| 171 | 
            +
                        for i, edge_vec1 in enumerate(edge_vec):
         | 
| 172 | 
            +
                            adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
         | 
| 173 | 
            +
                            adjacent_edge_vec = edge_vec[adjacent_ind]
         | 
| 174 | 
            +
                            temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
         | 
| 175 | 
            +
                            temp_adjacent_theta = self.vector_angle(
         | 
| 176 | 
            +
                                adjacent_edge_vec[0], adjacent_edge_vec[1]
         | 
| 177 | 
            +
                            )
         | 
| 178 | 
            +
                            theta_sum.append(temp_theta_sum)
         | 
| 179 | 
            +
                            adjacent_vec_theta.append(temp_adjacent_theta)
         | 
| 180 | 
            +
                        theta_sum_score = np.array(theta_sum) / np.pi
         | 
| 181 | 
            +
                        adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
         | 
| 182 | 
            +
                        poly_center = np.mean(points, axis=0)
         | 
| 183 | 
            +
                        edge_dist = np.maximum(
         | 
| 184 | 
            +
                            norm(pad_points[1:] - poly_center, axis=-1),
         | 
| 185 | 
            +
                            norm(pad_points[:-1] - poly_center, axis=-1),
         | 
| 186 | 
            +
                        )
         | 
| 187 | 
            +
                        dist_score = edge_dist / np.max(edge_dist)
         | 
| 188 | 
            +
                        position_score = np.zeros(len(edge_vec))
         | 
| 189 | 
            +
                        score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
         | 
| 190 | 
            +
                        score += 0.35 * dist_score
         | 
| 191 | 
            +
                        if len(points) % 2 == 0:
         | 
| 192 | 
            +
                            position_score[(len(score) // 2 - 1)] += 1
         | 
| 193 | 
            +
                            position_score[-1] += 1
         | 
| 194 | 
            +
                        score += 0.1 * position_score
         | 
| 195 | 
            +
                        pad_score = np.concatenate([score, score])
         | 
| 196 | 
            +
                        score_matrix = np.zeros((len(score), len(score) - 3))
         | 
| 197 | 
            +
                        x = np.arange(len(score) - 3) / float(len(score) - 4)
         | 
| 198 | 
            +
                        gaussian = (
         | 
| 199 | 
            +
                            1.0
         | 
| 200 | 
            +
                            / (np.sqrt(2.0 * np.pi) * 0.5)
         | 
| 201 | 
            +
                            * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
         | 
| 202 | 
            +
                        )
         | 
| 203 | 
            +
                        gaussian = gaussian / np.max(gaussian)
         | 
| 204 | 
            +
                        for i in range(len(score)):
         | 
| 205 | 
            +
                            score_matrix[i, :] = (
         | 
| 206 | 
            +
                                score[i]
         | 
| 207 | 
            +
                                + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
         | 
| 208 | 
            +
                            )
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                        head_start, tail_increment = np.unravel_index(
         | 
| 211 | 
            +
                            score_matrix.argmax(), score_matrix.shape
         | 
| 212 | 
            +
                        )
         | 
| 213 | 
            +
                        tail_start = (head_start + tail_increment + 2) % len(points)
         | 
| 214 | 
            +
                        head_end = (head_start + 1) % len(points)
         | 
| 215 | 
            +
                        tail_end = (tail_start + 1) % len(points)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                        if head_end > tail_end:
         | 
| 218 | 
            +
                            head_start, tail_start = tail_start, head_start
         | 
| 219 | 
            +
                            head_end, tail_end = tail_end, head_end
         | 
| 220 | 
            +
                        head_inds = [head_start, head_end]
         | 
| 221 | 
            +
                        tail_inds = [tail_start, tail_end]
         | 
| 222 | 
            +
                    else:
         | 
| 223 | 
            +
                        if vector_slope(points[1] - points[0]) + vector_slope(
         | 
| 224 | 
            +
                            points[3] - points[2]
         | 
| 225 | 
            +
                        ) < vector_slope(points[2] - points[1]) + vector_slope(
         | 
| 226 | 
            +
                            points[0] - points[3]
         | 
| 227 | 
            +
                        ):
         | 
| 228 | 
            +
                            horizontal_edge_inds = [[0, 1], [2, 3]]
         | 
| 229 | 
            +
                            vertical_edge_inds = [[3, 0], [1, 2]]
         | 
| 230 | 
            +
                        else:
         | 
| 231 | 
            +
                            horizontal_edge_inds = [[3, 0], [1, 2]]
         | 
| 232 | 
            +
                            vertical_edge_inds = [[0, 1], [2, 3]]
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                        vertical_len_sum = norm(
         | 
| 235 | 
            +
                            points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
         | 
| 236 | 
            +
                        ) + norm(
         | 
| 237 | 
            +
                            points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
         | 
| 238 | 
            +
                        )
         | 
| 239 | 
            +
                        horizontal_len_sum = norm(
         | 
| 240 | 
            +
                            points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
         | 
| 241 | 
            +
                        ) + norm(
         | 
| 242 | 
            +
                            points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
         | 
| 243 | 
            +
                        )
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                        if vertical_len_sum > horizontal_len_sum * orientation_thr:
         | 
| 246 | 
            +
                            head_inds = horizontal_edge_inds[0]
         | 
| 247 | 
            +
                            tail_inds = horizontal_edge_inds[1]
         | 
| 248 | 
            +
                        else:
         | 
| 249 | 
            +
                            head_inds = vertical_edge_inds[0]
         | 
| 250 | 
            +
                            tail_inds = vertical_edge_inds[1]
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    return head_inds, tail_inds
         | 
| 253 | 
            +
             | 
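    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # For an axis-aligned quadrangle such as
    #   points = np.array([[0, 0], [10, 0], [10, 4], [0, 4]], dtype=np.float32)
    # the slope comparison selects [[0, 1], [2, 3]] as the horizontal edges, and
    # because the vertical edge lengths (4 + 4) do not exceed the horizontal
    # ones (10 + 10) times orientation_thr (2.0), the short left/right edges
    # [3, 0] and [1, 2] are returned as the head and tail of the text instance.
    # --------------------------------------------------------------------------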
    def resample_sidelines(self, sideline1, sideline2, resample_step):
        """Resample two sidelines to be of the same points number according to
        step size.

        Args:
            sideline1 (ndarray): The points composing a sideline of a text
                polygon.
            sideline2 (ndarray): The points composing another sideline of a
                text polygon.
            resample_step (float): The resampled step size.

        Returns:
            resampled_line1 (ndarray): The resampled line 1.
            resampled_line2 (ndarray): The resampled line 2.
        """

        assert sideline1.ndim == sideline2.ndim == 2
        assert sideline1.shape[1] == sideline2.shape[1] == 2
        assert sideline1.shape[0] >= 2
        assert sideline2.shape[0] >= 2
        assert isinstance(resample_step, float)

        length1 = sum(
            [norm(sideline1[i + 1] - sideline1[i]) for i in range(len(sideline1) - 1)]
        )
        length2 = sum(
            [norm(sideline2[i + 1] - sideline2[i]) for i in range(len(sideline2) - 1)]
        )

        total_length = (length1 + length2) / 2
        resample_point_num = max(int(float(total_length) / resample_step), 1)

        resampled_line1 = self.resample_line(sideline1, resample_point_num)
        resampled_line2 = self.resample_line(sideline2, resample_point_num)

        return resampled_line1, resampled_line2

    def generate_center_region_mask(self, img_size, text_polys):
        """Generate text center region mask.

        Args:
            img_size (tuple): The image size of (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            center_region_mask (ndarray): The text center region mask.
        """

        assert isinstance(img_size, tuple)
        # assert check_argument.is_2dlist(text_polys)

        h, w = img_size

        center_region_mask = np.zeros((h, w), np.uint8)

        center_region_boxes = []
        for poly in text_polys:
            # assert len(poly) == 1
            polygon_points = poly.reshape(-1, 2)
            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
            resampled_top_line, resampled_bot_line = self.resample_sidelines(
                top_line, bot_line, self.resample_step
            )
            resampled_bot_line = resampled_bot_line[::-1]
            center_line = (resampled_top_line + resampled_bot_line) / 2

            line_head_shrink_len = (
                norm(resampled_top_line[0] - resampled_bot_line[0]) / 4.0
            )
            line_tail_shrink_len = (
                norm(resampled_top_line[-1] - resampled_bot_line[-1]) / 4.0
            )
            head_shrink_num = int(line_head_shrink_len // self.resample_step)
            tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
            if len(center_line) > head_shrink_num + tail_shrink_num + 2:
                center_line = center_line[
                    head_shrink_num : len(center_line) - tail_shrink_num
                ]
                resampled_top_line = resampled_top_line[
                    head_shrink_num : len(resampled_top_line) - tail_shrink_num
                ]
                resampled_bot_line = resampled_bot_line[
                    head_shrink_num : len(resampled_bot_line) - tail_shrink_num
                ]

            for i in range(0, len(center_line) - 1):
                tl = (
                    center_line[i]
                    + (resampled_top_line[i] - center_line[i])
                    * self.center_region_shrink_ratio
                )
                tr = (
                    center_line[i + 1]
                    + (resampled_top_line[i + 1] - center_line[i + 1])
                    * self.center_region_shrink_ratio
                )
                br = (
                    center_line[i + 1]
                    + (resampled_bot_line[i + 1] - center_line[i + 1])
                    * self.center_region_shrink_ratio
                )
                bl = (
                    center_line[i]
                    + (resampled_bot_line[i] - center_line[i])
                    * self.center_region_shrink_ratio
                )
                current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
                center_region_boxes.append(current_center_box)

        cv2.fillPoly(center_region_mask, center_region_boxes, 1)
        return center_region_mask

    def resample_polygon(self, polygon, n=400):
        """Resample one polygon with n points on its boundary.

        Args:
            polygon (list[float]): The input polygon.
            n (int): The number of resampled points.
        Returns:
            resampled_polygon (list[float]): The resampled polygon.
        """
        length = []

        for i in range(len(polygon)):
            p1 = polygon[i]
            if i == len(polygon) - 1:
                p2 = polygon[0]
            else:
                p2 = polygon[i + 1]
            length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5)

        total_length = sum(length)
        n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
        n_on_each_line = n_on_each_line.astype(np.int32)
        new_polygon = []

        for i in range(len(polygon)):
            num = n_on_each_line[i]
            p1 = polygon[i]
            if i == len(polygon) - 1:
                p2 = polygon[0]
            else:
                p2 = polygon[i + 1]

            if num == 0:
                continue

            dxdy = (p2 - p1) / num
            for j in range(num):
                point = p1 + dxdy * j
                new_polygon.append(point)

        return np.array(new_polygon)

    def normalize_polygon(self, polygon):
        """Normalize one polygon so that its start point is at the right-most
        position.

        Args:
            polygon (list[float]): The origin polygon.
        Returns:
            new_polygon (list[float]): The polygon with start point at right.
        """
        temp_polygon = polygon - polygon.mean(axis=0)
        x = np.abs(temp_polygon[:, 0])
        y = temp_polygon[:, 1]
        index_x = np.argsort(x)
        index_y = np.argmin(y[index_x[:8]])
        index = index_x[index_y]
        new_polygon = np.concatenate([polygon[index:], polygon[:index]])
        return new_polygon

    def poly2fourier(self, polygon, fourier_degree):
        """Perform Fourier transformation to generate Fourier coefficients ck
        from polygon.

        Args:
            polygon (ndarray): An input polygon.
            fourier_degree (int): The maximum Fourier degree K.
        Returns:
            c (ndarray(complex)): Fourier coefficients.
        """
        points = polygon[:, 0] + polygon[:, 1] * 1j
        c_fft = fft(points) / len(points)
        c = np.hstack((c_fft[-fourier_degree:], c_fft[: fourier_degree + 1]))
        return c

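    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # A minimal sketch of the coefficient layout, assuming an instance
    # `targets = FCENetTargets(fourier_degree=5)`:
    #
    #   square = np.array([[0, 0], [4, 0], [4, 4], [0, 4]], dtype=np.float32)
    #   poly = targets.resample_polygon(square, n=400)
    #   c = targets.poly2fourier(poly, targets.fourier_degree)
    #   # c.shape == (11,): coefficients c_{-5} ... c_{5}; c[5] is the DC term,
    #   # i.e. the polygon centroid expressed as a complex number (~2 + 2j).
    # --------------------------------------------------------------------------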
    def clockwise(self, c, fourier_degree):
        """Make sure the polygon reconstructed from Fourier coefficients c is
        in the clockwise direction.

        Args:
            c (ndarray(complex)): The Fourier coefficients.
            fourier_degree (int): The maximum Fourier degree K.
        Returns:
            c (ndarray(complex)): The Fourier coefficients in clockwise point
                order.
        """
        if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
            return c
        elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
            return c[::-1]
        else:
            if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
                return c
            else:
                return c[::-1]

    def cal_fourier_signature(self, polygon, fourier_degree):
        """Calculate Fourier signature from input polygon.

        Args:
              polygon (ndarray): The input polygon.
              fourier_degree (int): The maximum Fourier degree K.
        Returns:
              fourier_signature (ndarray): An array shaped (2k+1, 2) containing
                  the real part and imaginary part of 2k+1 Fourier coefficients.
        """
        resampled_polygon = self.resample_polygon(polygon)
        resampled_polygon = self.normalize_polygon(resampled_polygon)

        fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
        fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)

        real_part = np.real(fourier_coeff).reshape((-1, 1))
        image_part = np.imag(fourier_coeff).reshape((-1, 1))
        fourier_signature = np.hstack([real_part, image_part])

        return fourier_signature

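    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # cal_fourier_signature chains the steps above: resample the boundary,
    # rotate the start point to the right-most position, FFT the points, and
    # force a clockwise orientation. Assuming `targets = FCENetTargets(fourier_degree=5)`:
    #
    #   square = np.array([[0, 0], [4, 0], [4, 4], [0, 4]], dtype=np.float32)
    #   sig = targets.cal_fourier_signature(square, targets.fourier_degree)
    #   # sig.shape == (11, 2): column 0 holds the real parts, column 1 the
    #   # imaginary parts of c_{-5} ... c_{5}.
    # --------------------------------------------------------------------------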
    def generate_fourier_maps(self, img_size, text_polys):
        """Generate Fourier coefficient maps.

        Args:
            img_size (tuple): The image size of (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            fourier_real_map (ndarray): The Fourier coefficient real part maps.
            fourier_image_map (ndarray): The Fourier coefficient imaginary part
                maps.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size
        k = self.fourier_degree
        real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
        imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)

        for poly in text_polys:
            mask = np.zeros((h, w), dtype=np.uint8)
            polygon = np.array(poly).reshape((1, -1, 2))
            cv2.fillPoly(mask, polygon.astype(np.int32), 1)
            fourier_coeff = self.cal_fourier_signature(polygon[0], k)
            for i in range(-k, k + 1):
                if i != 0:
                    real_map[i + k, :, :] = (
                        mask * fourier_coeff[i + k, 0]
                        + (1 - mask) * real_map[i + k, :, :]
                    )
                    imag_map[i + k, :, :] = (
                        mask * fourier_coeff[i + k, 1]
                        + (1 - mask) * imag_map[i + k, :, :]
                    )
                else:
                    yx = np.argwhere(mask > 0.5)
                    k_ind = np.ones((len(yx)), dtype=np.int64) * k
                    y, x = yx[:, 0], yx[:, 1]
                    real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
                    imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y

        return real_map, imag_map

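    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # Both returned stacks have 2k+1 channels (11 for the default k=5). Inside
    # a text region, channel i+k stores the constant coefficient c_i of that
    # instance, except the DC channel (i == 0), where each pixel instead stores
    # its offset to the instance centroid (real part of c_0 minus x, imaginary
    # part of c_0 minus y), so every pixel can regress its own contour.
    # --------------------------------------------------------------------------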
    def generate_text_region_mask(self, img_size, text_polys):
        """Generate the text region mask.

        Args:
            img_size (tuple): The image size (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            text_region_mask (ndarray): The text region mask.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size
        text_region_mask = np.zeros((h, w), dtype=np.uint8)

        for poly in text_polys:
            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
            cv2.fillPoly(text_region_mask, polygon, 1)

        return text_region_mask

    def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
        """Generate effective mask by setting the ineffective regions to 0 and
        effective regions to 1.

        Args:
            mask_size (tuple): The mask size.
            polygons_ignore (list[ndarray]): The list of ignored text
                polygons.

        Returns:
            mask (ndarray): The effective mask of (height, width).
        """

        mask = np.ones(mask_size, dtype=np.uint8)

        for poly in polygons_ignore:
            instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(mask, instance, 0)

        return mask

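    # --- Illustrative note (not part of the original PaddleOCR file) ---------
    # A minimal sketch, assuming `targets = FCENetTargets()`:
    #
    #   ignore = [np.array([2, 2, 6, 2, 6, 6, 2, 6], dtype=np.float32)]
    #   mask = targets.generate_effective_mask((10, 10), ignore)
    #   # mask is all ones except the filled square covering the ignored
    #   # polygon, which is zeroed out and treated as a "don't care" region.
    # --------------------------------------------------------------------------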
    def generate_level_targets(self, img_size, text_polys, ignore_polys):
        """Generate ground truth target on each level.

        Args:
            img_size (list[int]): Shape of input image.
            text_polys (list[list[ndarray]]): A list of ground truth polygons.
            ignore_polys (list[list[ndarray]]): A list of ignored polygons.
        Returns:
            level_maps (list(ndarray)): A list of ground truth targets on each
                level.
        """
        h, w = img_size
        lv_size_divs = self.level_size_divisors
        lv_proportion_range = self.level_proportion_range
        lv_text_polys = [[] for i in range(len(lv_size_divs))]
        lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
        level_maps = []
        for poly in text_polys:
            # np.int was removed in NumPy 1.24; use a concrete integer dtype.
            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
            _, _, box_w, box_h = cv2.boundingRect(polygon)
            proportion = max(box_h, box_w) / (h + 1e-8)

            for ind, proportion_range in enumerate(lv_proportion_range):
                if proportion_range[0] < proportion < proportion_range[1]:
                    lv_text_polys[ind].append(poly / lv_size_divs[ind])

        for ignore_poly in ignore_polys:
            polygon = np.array(ignore_poly, dtype=np.int32).reshape((1, -1, 2))
            _, _, box_w, box_h = cv2.boundingRect(polygon)
            proportion = max(box_h, box_w) / (h + 1e-8)

            for ind, proportion_range in enumerate(lv_proportion_range):
                if proportion_range[0] < proportion < proportion_range[1]:
                    lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])

        for ind, size_divisor in enumerate(lv_size_divs):
            current_level_maps = []
            level_img_size = (h // size_divisor, w // size_divisor)

            text_region = self.generate_text_region_mask(
                level_img_size, lv_text_polys[ind]
            )[None]
            current_level_maps.append(text_region)

            center_region = self.generate_center_region_mask(
                level_img_size, lv_text_polys[ind]
            )[None]
            current_level_maps.append(center_region)

            effective_mask = self.generate_effective_mask(
                level_img_size, lv_ignore_polys[ind]
            )[None]
            current_level_maps.append(effective_mask)

            fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
                level_img_size, lv_text_polys[ind]
            )
            current_level_maps.append(fourier_real_map)
            current_level_maps.append(fourier_image_maps)

            level_maps.append(np.concatenate(current_level_maps))

        return level_maps

| 631 | 
            +
                def generate_targets(self, results):
         | 
| 632 | 
            +
                    """Generate the ground truth targets for FCENet.
         | 
| 633 | 
            +
             | 
| 634 | 
            +
                    Args:
         | 
| 635 | 
            +
                        results (dict): The input result dictionary.
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                    Returns:
         | 
| 638 | 
            +
                        results (dict): The output result dictionary.
         | 
| 639 | 
            +
                    """
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    assert isinstance(results, dict)
         | 
| 642 | 
            +
                    image = results["image"]
         | 
| 643 | 
            +
                    polygons = results["polys"]
         | 
| 644 | 
            +
                    ignore_tags = results["ignore_tags"]
         | 
| 645 | 
            +
                    h, w, _ = image.shape
         | 
| 646 | 
            +
             | 
| 647 | 
            +
                    polygon_masks = []
         | 
| 648 | 
            +
                    polygon_masks_ignore = []
         | 
| 649 | 
            +
                    for tag, polygon in zip(ignore_tags, polygons):
         | 
| 650 | 
            +
                        if tag is True:
         | 
| 651 | 
            +
                            polygon_masks_ignore.append(polygon)
         | 
| 652 | 
            +
                        else:
         | 
| 653 | 
            +
                            polygon_masks.append(polygon)
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    level_maps = self.generate_level_targets(
         | 
| 656 | 
            +
                        (h, w), polygon_masks, polygon_masks_ignore
         | 
| 657 | 
            +
                    )
         | 
| 658 | 
            +
             | 
| 659 | 
            +
                    mapping = {
         | 
| 660 | 
            +
                        "p3_maps": level_maps[0],
         | 
| 661 | 
            +
                        "p4_maps": level_maps[1],
         | 
| 662 | 
            +
                        "p5_maps": level_maps[2],
         | 
| 663 | 
            +
                    }
         | 
| 664 | 
            +
                    for key, value in mapping.items():
         | 
| 665 | 
            +
                        results[key] = value
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                    return results
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                def __call__(self, results):
         | 
| 670 | 
            +
                    results = self.generate_targets(results)
         | 
| 671 | 
            +
                    return results
         | 
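
For orientation, here is a minimal sketch of how this transform is typically driven. The class name FCENetTargets and the constructor arguments (fourier_degree, level_size_divisors, level_proportion_range) follow the upstream PaddleOCR implementation and are not defined in this excerpt, so treat them as assumptions; what the excerpt does fix is the input contract of generate_targets (a dict with "image", "polys" and "ignore_tags") and the output keys p3_maps / p4_maps / p5_maps.

    import numpy as np

    # Hypothetical instantiation; the argument names come from upstream PaddleOCR,
    # not from this diff.
    fce_targets = FCENetTargets(
        fourier_degree=5,
        level_size_divisors=(8, 16, 32),
        level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0)),
    )

    results = {
        "image": np.zeros((640, 640, 3), dtype=np.uint8),  # dummy image
        "polys": [np.array([[10, 10], [200, 10], [200, 60], [10, 60]], dtype=np.float32)],
        "ignore_tags": [False],
    }
    results = fce_targets(results)  # __call__ delegates to generate_targets
    # results["p3_maps"], results["p4_maps"] and results["p5_maps"] now each hold the
    # concatenated text-region, center-region, effective-mask and Fourier maps for
    # one feature-pyramid level.
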
    	
        ocr/ppocr/data/imaug/gen_table_mask.py
    ADDED
    
    | @@ -0,0 +1,228 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class GenTableMask(object):
         | 
| 8 | 
            +
                """gen table mask"""
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
         | 
| 11 | 
            +
                    self.shrink_h_max = 5  # note: the shrink_h_max argument is ignored; 5 is hard-coded
         | 
| 12 | 
            +
                    self.shrink_w_max = 5
         | 
| 13 | 
            +
                    self.mask_type = mask_type
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def projection(self, erosion, h, w, spilt_threshold=0):
         | 
| 16 | 
            +
                    # horizontal projection
         | 
| 17 | 
            +
                    projection_map = np.ones_like(erosion)
         | 
| 18 | 
            +
                    project_val_array = [0 for _ in range(0, h)]
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    for j in range(0, h):
         | 
| 21 | 
            +
                        for i in range(0, w):
         | 
| 22 | 
            +
                            if erosion[j, i] == 255:
         | 
| 23 | 
            +
                                project_val_array[j] += 1
         | 
| 24 | 
            +
                    # find split points from the projection array
         | 
| 25 | 
            +
                    start_idx = 0  # index where a text region starts
         | 
| 26 | 
            +
                    end_idx = 0  # index where a blank region starts
         | 
| 27 | 
            +
                    in_text = False  # whether the scan is currently inside a text region
         | 
| 28 | 
            +
                    box_list = []
         | 
| 29 | 
            +
                    for i in range(len(project_val_array)):
         | 
| 30 | 
            +
                        if in_text == False and project_val_array[i] > spilt_threshold:  # entered a text region
         | 
| 31 | 
            +
                            in_text = True
         | 
| 32 | 
            +
                            start_idx = i
         | 
| 33 | 
            +
                        elif project_val_array[i] <= spilt_threshold and in_text == True:  # entered a blank region
         | 
| 34 | 
            +
                            end_idx = i
         | 
| 35 | 
            +
                            in_text = False
         | 
| 36 | 
            +
                            if end_idx - start_idx <= 2:
         | 
| 37 | 
            +
                                continue
         | 
| 38 | 
            +
                            box_list.append((start_idx, end_idx + 1))
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    if in_text:
         | 
| 41 | 
            +
                        box_list.append((start_idx, h - 1))
         | 
| 42 | 
            +
                    # draw the projection histogram
         | 
| 43 | 
            +
                    for j in range(0, h):
         | 
| 44 | 
            +
                        for i in range(0, project_val_array[j]):
         | 
| 45 | 
            +
                            projection_map[j, i] = 0
         | 
| 46 | 
            +
                    return box_list, projection_map
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def projection_cx(self, box_img):
         | 
| 49 | 
            +
                    box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
         | 
| 50 | 
            +
                    h, w = box_gray_img.shape
         | 
| 51 | 
            +
                    # binarize the grayscale image
         | 
| 52 | 
            +
                    ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
         | 
| 53 | 
            +
                    # vertical erosion
         | 
| 54 | 
            +
                    if h < w:
         | 
| 55 | 
            +
                        kernel = np.ones((2, 1), np.uint8)
         | 
| 56 | 
            +
                        erode = cv2.erode(thresh1, kernel, iterations=1)
         | 
| 57 | 
            +
                    else:
         | 
| 58 | 
            +
                        erode = thresh1
         | 
| 59 | 
            +
                    # horizontal dilation
         | 
| 60 | 
            +
                    kernel = np.ones((1, 5), np.uint8)
         | 
| 61 | 
            +
                    erosion = cv2.dilate(erode, kernel, iterations=1)
         | 
| 62 | 
            +
                    # horizontal projection
         | 
| 63 | 
            +
                    projection_map = np.ones_like(erosion)
         | 
| 64 | 
            +
                    project_val_array = [0 for _ in range(0, h)]
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    for j in range(0, h):
         | 
| 67 | 
            +
                        for i in range(0, w):
         | 
| 68 | 
            +
                            if erosion[j, i] == 255:
         | 
| 69 | 
            +
                                project_val_array[j] += 1
         | 
| 70 | 
            +
                    # find split points from the projection array
         | 
| 71 | 
            +
                    start_idx = 0  # index where a text region starts
         | 
| 72 | 
            +
                    end_idx = 0  # index where a blank region starts
         | 
| 73 | 
            +
                    in_text = False  # whether the scan is currently inside a text region
         | 
| 74 | 
            +
                    box_list = []
         | 
| 75 | 
            +
                    spilt_threshold = 0
         | 
| 76 | 
            +
                    for i in range(len(project_val_array)):
         | 
| 77 | 
            +
                        if in_text == False and project_val_array[i] > spilt_threshold:  # entered a text region
         | 
| 78 | 
            +
                            in_text = True
         | 
| 79 | 
            +
                            start_idx = i
         | 
| 80 | 
            +
                        elif project_val_array[i] <= spilt_threshold and in_text == True:  # entered a blank region
         | 
| 81 | 
            +
                            end_idx = i
         | 
| 82 | 
            +
                            in_text = False
         | 
| 83 | 
            +
                            if end_idx - start_idx <= 2:
         | 
| 84 | 
            +
                                continue
         | 
| 85 | 
            +
                            box_list.append((start_idx, end_idx + 1))
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    if in_text:
         | 
| 88 | 
            +
                        box_list.append((start_idx, h - 1))
         | 
| 89 | 
            +
                    # draw the projection histogram
         | 
| 90 | 
            +
                    for j in range(0, h):
         | 
| 91 | 
            +
                        for i in range(0, project_val_array[j]):
         | 
| 92 | 
            +
                            projection_map[j, i] = 0
         | 
| 93 | 
            +
                    split_bbox_list = []
         | 
| 94 | 
            +
                    if len(box_list) > 1:
         | 
| 95 | 
            +
                        for i, (h_start, h_end) in enumerate(box_list):
         | 
| 96 | 
            +
                            if i == 0:
         | 
| 97 | 
            +
                                h_start = 0
         | 
| 98 | 
            +
                            if i == len(box_list) - 1:
         | 
| 99 | 
            +
                                h_end = h
         | 
| 100 | 
            +
                            word_img = erosion[h_start : h_end + 1, :]
         | 
| 101 | 
            +
                            word_h, word_w = word_img.shape
         | 
| 102 | 
            +
                            w_split_list, w_projection_map = self.projection(
         | 
| 103 | 
            +
                                word_img.T, word_w, word_h
         | 
| 104 | 
            +
                            )
         | 
| 105 | 
            +
                            w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
         | 
| 106 | 
            +
                            if h_start > 0:
         | 
| 107 | 
            +
                                h_start -= 1
         | 
| 108 | 
            +
                            h_end += 1
         | 
| 109 | 
            +
                            word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :]
         | 
| 110 | 
            +
                            split_bbox_list.append([w_start, h_start, w_end, h_end])
         | 
| 111 | 
            +
                    else:
         | 
| 112 | 
            +
                        split_bbox_list.append([0, 0, w, h])
         | 
| 113 | 
            +
                    return split_bbox_list
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def shrink_bbox(self, bbox):
         | 
| 116 | 
            +
                    left, top, right, bottom = bbox
         | 
| 117 | 
            +
                    sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
         | 
| 118 | 
            +
                    sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
         | 
| 119 | 
            +
                    left_new = left + sh_w
         | 
| 120 | 
            +
                    right_new = right - sh_w
         | 
| 121 | 
            +
                    top_new = top + sh_h
         | 
| 122 | 
            +
                    bottom_new = bottom - sh_h
         | 
| 123 | 
            +
                    if left_new >= right_new:
         | 
| 124 | 
            +
                        left_new = left
         | 
| 125 | 
            +
                        right_new = right
         | 
| 126 | 
            +
                    if top_new >= bottom_new:
         | 
| 127 | 
            +
                        top_new = top
         | 
| 128 | 
            +
                        bottom_new = bottom
         | 
| 129 | 
            +
                    return [left_new, top_new, right_new, bottom_new]
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def __call__(self, data):
         | 
| 132 | 
            +
                    img = data["image"]
         | 
| 133 | 
            +
                    cells = data["cells"]
         | 
| 134 | 
            +
                    height, width = img.shape[0:2]
         | 
| 135 | 
            +
                    if self.mask_type == 1:
         | 
| 136 | 
            +
                        mask_img = np.zeros((height, width), dtype=np.float32)
         | 
| 137 | 
            +
                    else:
         | 
| 138 | 
            +
                        mask_img = np.zeros((height, width, 3), dtype=np.float32)
         | 
| 139 | 
            +
                    cell_num = len(cells)
         | 
| 140 | 
            +
                    for cno in range(cell_num):
         | 
| 141 | 
            +
                        if "bbox" in cells[cno]:
         | 
| 142 | 
            +
                            bbox = cells[cno]["bbox"]
         | 
| 143 | 
            +
                            left, top, right, bottom = bbox
         | 
| 144 | 
            +
                            box_img = img[top:bottom, left:right, :].copy()
         | 
| 145 | 
            +
                            split_bbox_list = self.projection_cx(box_img)
         | 
| 146 | 
            +
                            for sno in range(len(split_bbox_list)):
         | 
| 147 | 
            +
                                split_bbox_list[sno][0] += left
         | 
| 148 | 
            +
                                split_bbox_list[sno][1] += top
         | 
| 149 | 
            +
                                split_bbox_list[sno][2] += left
         | 
| 150 | 
            +
                                split_bbox_list[sno][3] += top
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                            for sno in range(len(split_bbox_list)):
         | 
| 153 | 
            +
                                left, top, right, bottom = split_bbox_list[sno]
         | 
| 154 | 
            +
                                left, top, right, bottom = self.shrink_bbox(
         | 
| 155 | 
            +
                                    [left, top, right, bottom]
         | 
| 156 | 
            +
                                )
         | 
| 157 | 
            +
                                if self.mask_type == 1:
         | 
| 158 | 
            +
                                    mask_img[top:bottom, left:right] = 1.0
         | 
| 159 | 
            +
                                    data["mask_img"] = mask_img
         | 
| 160 | 
            +
                                else:
         | 
| 161 | 
            +
                                    mask_img[top:bottom, left:right, :] = (255, 255, 255)
         | 
| 162 | 
            +
                                    data["image"] = mask_img
         | 
| 163 | 
            +
                    return data
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            class ResizeTableImage(object):
         | 
| 167 | 
            +
                def __init__(self, max_len, **kwargs):
         | 
| 168 | 
            +
                    super(ResizeTableImage, self).__init__()
         | 
| 169 | 
            +
                    self.max_len = max_len
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def get_img_bbox(self, cells):
         | 
| 172 | 
            +
                    bbox_list = []
         | 
| 173 | 
            +
                    if len(cells) == 0:
         | 
| 174 | 
            +
                        return bbox_list
         | 
| 175 | 
            +
                    cell_num = len(cells)
         | 
| 176 | 
            +
                    for cno in range(cell_num):
         | 
| 177 | 
            +
                        if "bbox" in cells[cno]:
         | 
| 178 | 
            +
                            bbox = cells[cno]["bbox"]
         | 
| 179 | 
            +
                            bbox_list.append(bbox)
         | 
| 180 | 
            +
                    return bbox_list
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def resize_img_table(self, img, bbox_list, max_len):
         | 
| 183 | 
            +
                    height, width = img.shape[0:2]
         | 
| 184 | 
            +
                    ratio = max_len / (max(height, width) * 1.0)
         | 
| 185 | 
            +
                    resize_h = int(height * ratio)
         | 
| 186 | 
            +
                    resize_w = int(width * ratio)
         | 
| 187 | 
            +
                    img_new = cv2.resize(img, (resize_w, resize_h))
         | 
| 188 | 
            +
                    bbox_list_new = []
         | 
| 189 | 
            +
                    for bno in range(len(bbox_list)):
         | 
| 190 | 
            +
                        left, top, right, bottom = bbox_list[bno].copy()
         | 
| 191 | 
            +
                        left = int(left * ratio)
         | 
| 192 | 
            +
                        top = int(top * ratio)
         | 
| 193 | 
            +
                        right = int(right * ratio)
         | 
| 194 | 
            +
                        bottom = int(bottom * ratio)
         | 
| 195 | 
            +
                        bbox_list_new.append([left, top, right, bottom])
         | 
| 196 | 
            +
                    return img_new, bbox_list_new
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def __call__(self, data):
         | 
| 199 | 
            +
                    img = data["image"]
         | 
| 200 | 
            +
                    if "cells" not in data:
         | 
| 201 | 
            +
                        cells = []
         | 
| 202 | 
            +
                    else:
         | 
| 203 | 
            +
                        cells = data["cells"]
         | 
| 204 | 
            +
                    bbox_list = self.get_img_bbox(cells)
         | 
| 205 | 
            +
                    img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
         | 
| 206 | 
            +
                    data["image"] = img_new
         | 
| 207 | 
            +
                    cell_num = len(cells)
         | 
| 208 | 
            +
                    bno = 0
         | 
| 209 | 
            +
                    for cno in range(cell_num):
         | 
| 210 | 
            +
                        if "bbox" in data["cells"][cno]:
         | 
| 211 | 
            +
                            data["cells"][cno]["bbox"] = bbox_list_new[bno]
         | 
| 212 | 
            +
                            bno += 1
         | 
| 213 | 
            +
                    data["max_len"] = self.max_len
         | 
| 214 | 
            +
                    return data
         | 
| 215 | 
            +
             | 
| 216 | 
            +
             | 
| 217 | 
            +
            class PaddingTableImage(object):
         | 
| 218 | 
            +
                def __init__(self, **kwargs):
         | 
| 219 | 
            +
                    super(PaddingTableImage, self).__init__()
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                def __call__(self, data):
         | 
| 222 | 
            +
                    img = data["image"]
         | 
| 223 | 
            +
                    max_len = data["max_len"]
         | 
| 224 | 
            +
                    padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
         | 
| 225 | 
            +
                    height, width = img.shape[0:2]
         | 
| 226 | 
            +
                    padding_img[0:height, 0:width, :] = img.copy()
         | 
| 227 | 
            +
                    data["image"] = padding_img
         | 
| 228 | 
            +
                    return data
         | 
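
The three transforms above are meant to be chained over the same data dict. A minimal sketch follows; max_len=488 is a typical value for the PaddleOCR table models and is an assumption, as is the toy cell layout.

    import numpy as np

    transforms = [
        ResizeTableImage(max_len=488),          # rescale the image and every cell bbox
        PaddingTableImage(),                    # pad to a square max_len x max_len canvas
        GenTableMask(shrink_h_max=5, shrink_w_max=5, mask_type=1),  # per-cell text mask
    ]

    data = {
        "image": np.full((600, 800, 3), 255, dtype=np.uint8),    # blank white page
        "cells": [{"bbox": [50, 40, 300, 90]}, {"tokens": []}],  # one cell with a bbox, one without
    }
    for t in transforms:
        data = t(data)
    # data["image"] is now 488x488, data["mask_img"] holds the shrunken text mask,
    # and the surviving cell bbox has been rescaled to the resized coordinates.

With mask_type=0 (the default), GenTableMask instead overwrites data["image"] with a three-channel mask.
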
    	
        ocr/ppocr/data/imaug/iaa_augment.py
    ADDED
    
    | @@ -0,0 +1,72 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import imgaug
         | 
| 4 | 
            +
            import imgaug.augmenters as iaa
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class AugmenterBuilder(object):
         | 
| 9 | 
            +
                def __init__(self):
         | 
| 10 | 
            +
                    pass
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def build(self, args, root=True):
         | 
| 13 | 
            +
                    if args is None or len(args) == 0:
         | 
| 14 | 
            +
                        return None
         | 
| 15 | 
            +
                    elif isinstance(args, list):
         | 
| 16 | 
            +
                        if root:
         | 
| 17 | 
            +
                            sequence = [self.build(value, root=False) for value in args]
         | 
| 18 | 
            +
                            return iaa.Sequential(sequence)
         | 
| 19 | 
            +
                        else:
         | 
| 20 | 
            +
                            return getattr(iaa, args[0])(
         | 
| 21 | 
            +
                                *[self.to_tuple_if_list(a) for a in args[1:]]
         | 
| 22 | 
            +
                            )
         | 
| 23 | 
            +
                    elif isinstance(args, dict):
         | 
| 24 | 
            +
                        cls = getattr(iaa, args["type"])
         | 
| 25 | 
            +
                        return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
         | 
| 26 | 
            +
                    else:
         | 
| 27 | 
            +
                        raise RuntimeError("unknown augmenter arg: " + str(args))
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def to_tuple_if_list(self, obj):
         | 
| 30 | 
            +
                    if isinstance(obj, list):
         | 
| 31 | 
            +
                        return tuple(obj)
         | 
| 32 | 
            +
                    return obj
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            class IaaAugment:
         | 
| 36 | 
            +
                def __init__(self, augmenter_args=None, **kwargs):
         | 
| 37 | 
            +
                    if augmenter_args is None:
         | 
| 38 | 
            +
                        augmenter_args = [
         | 
| 39 | 
            +
                            {"type": "Fliplr", "args": {"p": 0.5}},
         | 
| 40 | 
            +
                            {"type": "Affine", "args": {"rotate": [-10, 10]}},
         | 
| 41 | 
            +
                            {"type": "Resize", "args": {"size": [0.5, 3]}},
         | 
| 42 | 
            +
                        ]
         | 
| 43 | 
            +
                    self.augmenter = AugmenterBuilder().build(augmenter_args)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def __call__(self, data):
         | 
| 46 | 
            +
                    image = data["image"]
         | 
| 47 | 
            +
                    shape = image.shape
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    if self.augmenter:
         | 
| 50 | 
            +
                        aug = self.augmenter.to_deterministic()
         | 
| 51 | 
            +
                        data["image"] = aug.augment_image(image)
         | 
| 52 | 
            +
                        data = self.may_augment_annotation(aug, data, shape)
         | 
| 53 | 
            +
                    return data
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def may_augment_annotation(self, aug, data, shape):
         | 
| 56 | 
            +
                    if aug is None:
         | 
| 57 | 
            +
                        return data
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    line_polys = []
         | 
| 60 | 
            +
                    for poly in data["polys"]:
         | 
| 61 | 
            +
                        new_poly = self.may_augment_poly(aug, shape, poly)
         | 
| 62 | 
            +
                        line_polys.append(new_poly)
         | 
| 63 | 
            +
                    data["polys"] = np.array(line_polys)
         | 
| 64 | 
            +
                    return data
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def may_augment_poly(self, aug, img_shape, poly):
         | 
| 67 | 
            +
                    keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
         | 
| 68 | 
            +
                    keypoints = aug.augment_keypoints(
         | 
| 69 | 
            +
                        [imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
         | 
| 70 | 
            +
                    )[0].keypoints
         | 
| 71 | 
            +
                    poly = [(p.x, p.y) for p in keypoints]
         | 
| 72 | 
            +
                    return poly
         | 
    	
        ocr/ppocr/data/imaug/label_ops.py
    ADDED
    
    | @@ -0,0 +1,1046 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import copy
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from shapely.geometry import LineString, Point, Polygon
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class ClsLabelEncode(object):
         | 
| 11 | 
            +
                def __init__(self, label_list, **kwargs):
         | 
| 12 | 
            +
                    self.label_list = label_list
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __call__(self, data):
         | 
| 15 | 
            +
                    label = data["label"]
         | 
| 16 | 
            +
                    if label not in self.label_list:
         | 
| 17 | 
            +
                        return None
         | 
| 18 | 
            +
                    label = self.label_list.index(label)
         | 
| 19 | 
            +
                    data["label"] = label
         | 
| 20 | 
            +
                    return data
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class DetLabelEncode(object):
         | 
| 24 | 
            +
                def __init__(self, **kwargs):
         | 
| 25 | 
            +
                    pass
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def __call__(self, data):
         | 
| 28 | 
            +
                    label = data["label"]
         | 
| 29 | 
            +
                    label = json.loads(label)
         | 
| 30 | 
            +
                    nBox = len(label)
         | 
| 31 | 
            +
                    boxes, txts, txt_tags = [], [], []
         | 
| 32 | 
            +
                    for bno in range(0, nBox):
         | 
| 33 | 
            +
                        box = label[bno]["points"]
         | 
| 34 | 
            +
                        txt = label[bno]["transcription"]
         | 
| 35 | 
            +
                        boxes.append(box)
         | 
| 36 | 
            +
                        txts.append(txt)
         | 
| 37 | 
            +
                        if txt in ["*", "###"]:
         | 
| 38 | 
            +
                            txt_tags.append(True)
         | 
| 39 | 
            +
                        else:
         | 
| 40 | 
            +
                            txt_tags.append(False)
         | 
| 41 | 
            +
                    if len(boxes) == 0:
         | 
| 42 | 
            +
                        return None
         | 
| 43 | 
            +
                    boxes = self.expand_points_num(boxes)
         | 
| 44 | 
            +
                    boxes = np.array(boxes, dtype=np.float32)
         | 
| 45 | 
            +
                    txt_tags = np.array(txt_tags, dtype=bool)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    data["polys"] = boxes
         | 
| 48 | 
            +
                    data["texts"] = txts
         | 
| 49 | 
            +
                    data["ignore_tags"] = txt_tags
         | 
| 50 | 
            +
                    return data
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def order_points_clockwise(self, pts):
         | 
| 53 | 
            +
                    rect = np.zeros((4, 2), dtype="float32")
         | 
| 54 | 
            +
                    s = pts.sum(axis=1)
         | 
| 55 | 
            +
                    rect[0] = pts[np.argmin(s)]
         | 
| 56 | 
            +
                    rect[2] = pts[np.argmax(s)]
         | 
| 57 | 
            +
                    diff = np.diff(pts, axis=1)
         | 
| 58 | 
            +
                    rect[1] = pts[np.argmin(diff)]
         | 
| 59 | 
            +
                    rect[3] = pts[np.argmax(diff)]
         | 
| 60 | 
            +
                    return rect
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def expand_points_num(self, boxes):
         | 
| 63 | 
            +
                    max_points_num = 0
         | 
| 64 | 
            +
                    for box in boxes:
         | 
| 65 | 
            +
                        if len(box) > max_points_num:
         | 
| 66 | 
            +
                            max_points_num = len(box)
         | 
| 67 | 
            +
                    ex_boxes = []
         | 
| 68 | 
            +
                    for box in boxes:
         | 
| 69 | 
            +
                        ex_box = box + [box[-1]] * (max_points_num - len(box))
         | 
| 70 | 
            +
                        ex_boxes.append(ex_box)
         | 
| 71 | 
            +
                    return ex_boxes
         | 
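
DetLabelEncode expects the raw detection label as a JSON string; a minimal sketch of the round trip (the sample transcriptions and points are made up):

    import json

    det_encode = DetLabelEncode()
    label = json.dumps([
        {"transcription": "hello", "points": [[10, 10], [90, 10], [90, 40], [10, 40]]},
        {"transcription": "###",   "points": [[5, 60], [40, 60], [40, 80], [5, 80]]},
    ])
    data = det_encode({"label": label})
    # data["polys"]       -> float32 array of shape (2, 4, 2)
    # data["texts"]       -> ["hello", "###"]
    # data["ignore_tags"] -> array([False,  True]); "*" / "###" mark unreadable regions
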
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            class BaseRecLabelEncode(object):
         | 
| 75 | 
            +
                """Convert between text-label and text-index"""
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def __init__(self, max_text_length, character_dict_path=None, use_space_char=False):
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    self.max_text_len = max_text_length
         | 
| 80 | 
            +
                    self.beg_str = "sos"
         | 
| 81 | 
            +
                    self.end_str = "eos"
         | 
| 82 | 
            +
                    self.lower = False
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if character_dict_path is None:
         | 
| 85 | 
            +
                        self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
         | 
| 86 | 
            +
                        dict_character = list(self.character_str)
         | 
| 87 | 
            +
                        self.lower = True
         | 
| 88 | 
            +
                    else:
         | 
| 89 | 
            +
                        self.character_str = []
         | 
| 90 | 
            +
                        with open(character_dict_path, "rb") as fin:
         | 
| 91 | 
            +
                            lines = fin.readlines()
         | 
| 92 | 
            +
                            for line in lines:
         | 
| 93 | 
            +
                                line = line.decode("utf-8").strip("\n").strip("\r\n")
         | 
| 94 | 
            +
                                self.character_str.append(line)
         | 
| 95 | 
            +
                        if use_space_char:
         | 
| 96 | 
            +
                            self.character_str.append(" ")
         | 
| 97 | 
            +
                        dict_character = list(self.character_str)
         | 
| 98 | 
            +
                    dict_character = self.add_special_char(dict_character)
         | 
| 99 | 
            +
                    self.dict = {}
         | 
| 100 | 
            +
                    for i, char in enumerate(dict_character):
         | 
| 101 | 
            +
                        self.dict[char] = i
         | 
| 102 | 
            +
                    self.character = dict_character
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def add_special_char(self, dict_character):
         | 
| 105 | 
            +
                    return dict_character
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def encode(self, text):
         | 
| 108 | 
            +
                    """convert text-label into text-index.
         | 
| 109 | 
            +
                    input:
         | 
| 110 | 
            +
                        text: text labels of each image. [batch_size]
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    output:
         | 
| 113 | 
            +
                        text: concatenated text index for CTCLoss.
         | 
| 114 | 
            +
                                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
         | 
| 115 | 
            +
                        length: length of each text. [batch_size]
         | 
| 116 | 
            +
                    """
         | 
| 117 | 
            +
                    if len(text) == 0 or len(text) > self.max_text_len:
         | 
| 118 | 
            +
                        return None
         | 
| 119 | 
            +
                    if self.lower:
         | 
| 120 | 
            +
                        text = text.lower()
         | 
| 121 | 
            +
                    text_list = []
         | 
| 122 | 
            +
                    for char in text:
         | 
| 123 | 
            +
                        if char not in self.dict:
         | 
| 124 | 
            +
                            continue
         | 
| 125 | 
            +
                        text_list.append(self.dict[char])
         | 
| 126 | 
            +
                    if len(text_list) == 0:
         | 
| 127 | 
            +
                        return None
         | 
| 128 | 
            +
                    return text_list
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            class NRTRLabelEncode(BaseRecLabelEncode):
         | 
| 132 | 
            +
                """Convert between text-label and text-index"""
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def __init__(
         | 
| 135 | 
            +
                    self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
         | 
| 136 | 
            +
                ):
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    super(NRTRLabelEncode, self).__init__(
         | 
| 139 | 
            +
                        max_text_length, character_dict_path, use_space_char
         | 
| 140 | 
            +
                    )
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                def __call__(self, data):
         | 
| 143 | 
            +
                    text = data["label"]
         | 
| 144 | 
            +
                    text = self.encode(text)
         | 
| 145 | 
            +
                    if text is None:
         | 
| 146 | 
            +
                        return None
         | 
| 147 | 
            +
                    if len(text) >= self.max_text_len - 1:
         | 
| 148 | 
            +
                        return None
         | 
| 149 | 
            +
                    data["length"] = np.array(len(text))
         | 
| 150 | 
            +
                    text.insert(0, 2)
         | 
| 151 | 
            +
                    text.append(3)
         | 
| 152 | 
            +
                    text = text + [0] * (self.max_text_len - len(text))
         | 
| 153 | 
            +
                    data["label"] = np.array(text)
         | 
| 154 | 
            +
                    return data
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                def add_special_char(self, dict_character):
         | 
| 157 | 
            +
                    dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
         | 
| 158 | 
            +
                    return dict_character
         | 
| 159 | 
            +
             | 
| 160 | 
            +
             | 
| 161 | 
            +
            class CTCLabelEncode(BaseRecLabelEncode):
         | 
| 162 | 
            +
                """Convert between text-label and text-index"""
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def __init__(
         | 
| 165 | 
            +
                    self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
         | 
| 166 | 
            +
                ):
         | 
| 167 | 
            +
                    super(CTCLabelEncode, self).__init__(
         | 
| 168 | 
            +
                        max_text_length, character_dict_path, use_space_char
         | 
| 169 | 
            +
                    )
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def __call__(self, data):
         | 
| 172 | 
            +
                    text = data["label"]
         | 
| 173 | 
            +
                    text = self.encode(text)
         | 
| 174 | 
            +
                    if text is None:
         | 
| 175 | 
            +
                        return None
         | 
| 176 | 
            +
                    data["length"] = np.array(len(text))
         | 
| 177 | 
            +
                    text = text + [0] * (self.max_text_len - len(text))
         | 
| 178 | 
            +
                    data["label"] = np.array(text)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    label = [0] * len(self.character)
         | 
| 181 | 
            +
                    for x in text:
         | 
| 182 | 
            +
                        label[x] += 1
         | 
| 183 | 
            +
                    data["label_ace"] = np.array(label)
         | 
| 184 | 
            +
                    return data
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def add_special_char(self, dict_character):
         | 
| 187 | 
            +
                    dict_character = ["blank"] + dict_character
         | 
| 188 | 
            +
                    return dict_character
         | 
| 189 | 
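
And a sketch of the CTC label encoder using the built-in lowercase alphanumeric dictionary (character_dict_path=None); index 0 is reserved for the CTC blank by add_special_char above.

    ctc_encode = CTCLabelEncode(max_text_length=25)
    data = ctc_encode({"label": "OCR-2024"})
    # The text is lowercased, characters outside the dictionary ("-") are dropped,
    # and the index sequence is zero-padded to max_text_length:
    # data["label"].shape == (25,), data["length"] == 7, plus the per-character
    # count histogram in data["label_ace"].
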
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            class E2ELabelEncodeTest(BaseRecLabelEncode):
         | 
| 192 | 
            +
                def __init__(
         | 
| 193 | 
            +
                    self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
         | 
| 194 | 
            +
                ):
         | 
| 195 | 
            +
                    super(E2ELabelEncodeTest, self).__init__(
         | 
| 196 | 
            +
                        max_text_length, character_dict_path, use_space_char
         | 
| 197 | 
            +
                    )
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def __call__(self, data):
         | 
| 200 | 
            +
                    import json
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    padnum = len(self.dict)
         | 
| 203 | 
            +
                    label = data["label"]
         | 
| 204 | 
            +
                    label = json.loads(label)
         | 
| 205 | 
            +
                    nBox = len(label)
         | 
| 206 | 
            +
                    boxes, txts, txt_tags = [], [], []
         | 
| 207 | 
            +
                    for bno in range(0, nBox):
         | 
| 208 | 
            +
                        box = label[bno]["points"]
         | 
| 209 | 
            +
                        txt = label[bno]["transcription"]
         | 
| 210 | 
            +
                        boxes.append(box)
         | 
| 211 | 
            +
                        txts.append(txt)
         | 
| 212 | 
            +
                        if txt in ["*", "###"]:
         | 
| 213 | 
            +
                            txt_tags.append(True)
         | 
| 214 | 
            +
                        else:
         | 
| 215 | 
            +
                            txt_tags.append(False)
         | 
| 216 | 
            +
                    boxes = np.array(boxes, dtype=np.float32)
         | 
| 217 | 
            +
                    txt_tags = np.array(txt_tags, dtype=bool)
         | 
| 218 | 
            +
                    data["polys"] = boxes
         | 
| 219 | 
            +
                    data["ignore_tags"] = txt_tags
         | 
| 220 | 
            +
                    temp_texts = []
         | 
| 221 | 
            +
                    for text in txts:
         | 
| 222 | 
            +
                        text = text.lower()
         | 
| 223 | 
            +
                        text = self.encode(text)
         | 
| 224 | 
            +
                        if text is None:
         | 
| 225 | 
            +
                            return None
         | 
| 226 | 
            +
                        text = text + [padnum] * (self.max_text_len - len(text))  # pad with padnum (= len(self.dict), 36 for the default dict)
         | 
| 227 | 
            +
                        temp_texts.append(text)
         | 
| 228 | 
            +
                    data["texts"] = np.array(temp_texts)
         | 
| 229 | 
            +
                    return data
         | 
| 230 | 
            +
             | 
| 231 | 
            +
             | 
| 232 | 
            +
            class E2ELabelEncodeTrain(object):
         | 
| 233 | 
            +
                def __init__(self, **kwargs):
         | 
| 234 | 
            +
                    pass
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def __call__(self, data):
         | 
| 237 | 
            +
                    import json
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    label = data["label"]
         | 
| 240 | 
            +
                    label = json.loads(label)
         | 
| 241 | 
            +
                    nBox = len(label)
         | 
| 242 | 
            +
                    boxes, txts, txt_tags = [], [], []
         | 
| 243 | 
            +
                    for bno in range(0, nBox):
         | 
| 244 | 
            +
                        box = label[bno]["points"]
         | 
| 245 | 
            +
                        txt = label[bno]["transcription"]
         | 
| 246 | 
            +
                        boxes.append(box)
         | 
| 247 | 
            +
                        txts.append(txt)
         | 
| 248 | 
            +
                        if txt in ["*", "###"]:
         | 
| 249 | 
            +
                            txt_tags.append(True)
         | 
| 250 | 
            +
                        else:
         | 
| 251 | 
            +
                            txt_tags.append(False)
         | 
| 252 | 
            +
                    boxes = np.array(boxes, dtype=np.float32)
         | 
| 253 | 
            +
        txt_tags = np.array(txt_tags, dtype=bool)

        data["polys"] = boxes
        data["texts"] = txts
        data["ignore_tags"] = txt_tags
        return data


class KieLabelEncode(object):
    def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
        super(KieLabelEncode, self).__init__()
        self.dict = dict({"": 0})
        with open(character_dict_path, "r", encoding="utf-8") as fr:
            idx = 1
            for line in fr:
                char = line.strip()
                self.dict[char] = idx
                idx += 1
        self.norm = norm
        self.directed = directed

    def compute_relation(self, boxes):
        """Compute relation between every two boxes."""
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def pad_text_indices(self, text_inds):
        """Pad text index to same length."""
        max_len = 300
        recoder_len = max([len(text_ind) for text_ind in text_inds])
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, : len(text_ind)] = np.array(text_ind)
        return padded_text_inds, recoder_len

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos["points"], ann_infos["text_inds"]
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get("labels", None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get("edges", None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
        max_num = 300
        temp_bboxes = np.zeros([max_num, 4])
        h, _ = bboxes.shape
        temp_bboxes[:h, :] = bboxes

        temp_relations = np.zeros([max_num, max_num, 5])
        temp_relations[:h, :h, :] = relations

        temp_padded_text_inds = np.zeros([max_num, max_num])
        temp_padded_text_inds[:h, :] = padded_text_inds

        temp_labels = np.zeros([max_num, max_num])
        temp_labels[:h, : h + 1] = labels

        tag = np.array([h, recoder_len])
        return dict(
            image=ann_infos["image"],
            points=temp_bboxes,
            relations=temp_relations,
            texts=temp_padded_text_inds,
            labels=temp_labels,
            tag=tag,
        )

    def convert_canonical(self, points_x, points_y):

        assert len(points_x) == 4
        assert len(points_y) == 4

        points = [Point(points_x[i], points_y[i]) for i in range(4)]

        polygon = Polygon([(p.x, p.y) for p in points])
        min_x, min_y, _, _ = polygon.bounds
        points_to_lefttop = [
            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
        ]
        distances = np.array([line.length for line in points_to_lefttop])
        sort_dist_idx = np.argsort(distances)
        lefttop_idx = sort_dist_idx[0]

        if lefttop_idx == 0:
            point_orders = [0, 1, 2, 3]
        elif lefttop_idx == 1:
            point_orders = [1, 2, 3, 0]
        elif lefttop_idx == 2:
            point_orders = [2, 3, 0, 1]
        else:
            point_orders = [3, 0, 1, 2]

        sorted_points_x = [points_x[i] for i in point_orders]
        sorted_points_y = [points_y[j] for j in point_orders]

        return sorted_points_x, sorted_points_y

    def sort_vertex(self, points_x, points_y):

        assert len(points_x) == 4
        assert len(points_y) == 4

        x = np.array(points_x)
        y = np.array(points_y)
        center_x = np.sum(x) * 0.25
        center_y = np.sum(y) * 0.25

        x_arr = np.array(x - center_x)
        y_arr = np.array(y - center_y)

        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
        sort_idx = np.argsort(angle)

        sorted_points_x, sorted_points_y = [], []
        for i in range(4):
            sorted_points_x.append(points_x[sort_idx[i]])
            sorted_points_y.append(points_y[sort_idx[i]])

        return self.convert_canonical(sorted_points_x, sorted_points_y)

    def __call__(self, data):
        import json

        label = data["label"]
        annotations = json.loads(label)
        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann["points"]
            x_list = [box[i][0] for i in range(4)]
            y_list = [box[i][1] for i in range(4)]
            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
            sorted_box = []
            for x, y in zip(sorted_x_list, sorted_y_list):
                sorted_box.append(x)
                sorted_box.append(y)
            boxes.append(sorted_box)
            text = ann["transcription"]
            texts.append(ann["transcription"])
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            if "label" in ann.keys():
                labels.append(ann["label"])
            elif "key_cls" in ann.keys():
                labels.append(ann["key_cls"])
            else:
                raise ValueError(
                    "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
                )
            edges.append(ann.get("edge", 0))
        ann_infos = dict(
            image=data["image"],
            points=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels,
        )

        return self.list_to_numpy(ann_infos)

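# A minimal worked sketch of the vertex ordering above (hypothetical numbers,
# illustration only): sort_vertex orders the four corners by their angle around
# the centroid, and convert_canonical then rotates that order so the corner
# closest to the top-left of the bounding box comes first.
#
#   points_x, points_y = [4, 0, 0, 4], [2, 2, 0, 0]   # a 4x2 box, scrambled
#   # centroid = (2, 1); angles ~ [26.6, 153.4, -153.4, -26.6] degrees
#   # argsort(angle) -> [2, 3, 0, 1]; the point nearest (min_x, min_y) is
#   # already first, so the canonical result is:
#   #   sorted_points_x = [0, 4, 4, 0]
#   #   sorted_points_y = [0, 0, 2, 2]
#
# i.e. the corners come back clockwise (in image coordinates) starting from the
# top-left, before being flattened into the 8-value `sorted_box` in __call__.
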
class AttnLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index"""

    def __init__(
        self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
    ):
        super(AttnLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char
        )

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def __call__(self, data):
        text = data["label"]
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len:
            return None
        data["length"] = np.array(len(text))
        text = (
            [0]
            + text
            + [len(self.character) - 1]
            + [0] * (self.max_text_len - len(text) - 2)
        )
        data["label"] = np.array(text)
        return data

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

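# Sketch of the label layout produced by AttnLabelEncode.__call__ (hypothetical
# sizes, illustration only): the encoded text is wrapped as [sos] + text + [eos]
# and right-padded with the "sos"/0 index up to max_text_len. For example, with
# text encoded as [3, 1, 4], max_text_len = 10 and len(self.character) = 40
# (so "eos" sits at index 39):
#
#   data["length"] -> 3
#   data["label"]  -> [0, 3, 1, 4, 39, 0, 0, 0, 0, 0]
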
class SEEDLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index"""

    def __init__(
        self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
    ):
        super(SEEDLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char
        )

    def add_special_char(self, dict_character):
        self.padding = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        dict_character = dict_character + [self.end_str, self.padding, self.unknown]
        return dict_character

    def __call__(self, data):
        text = data["label"]
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len:
            return None
        data["length"] = np.array(len(text)) + 1  # conclude eos
        text = (
            text
            + [len(self.character) - 3]
            + [len(self.character) - 2] * (self.max_text_len - len(text) - 1)
        )
        data["label"] = np.array(text)
        return data


class SRNLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index"""

    def __init__(
        self,
        max_text_length=25,
        character_dict_path=None,
        use_space_char=False,
        **kwargs
    ):
        super(SRNLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char
        )

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def __call__(self, data):
        text = data["label"]
        text = self.encode(text)
        char_num = len(self.character)
        if text is None:
            return None
        if len(text) > self.max_text_len:
            return None
        data["length"] = np.array(len(text))
        text = text + [char_num - 1] * (self.max_text_len - len(text))
        data["label"] = np.array(text)
        return data

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

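# Sketch of SRN label padding (hypothetical sizes, illustration only): unlike
# the attention encoder above, SRNLabelEncode does not wrap the label with
# sos/eos; it only right-pads the encoded text with the last dictionary index
# (char_num - 1). For example, with text encoded as [5, 2], max_text_len = 6
# and char_num = 40:
#
#   data["label"] -> [5, 2, 39, 39, 39, 39]
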
class TableLabelEncode(object):
    """Convert between text-label and text-index"""

    def __init__(
        self,
        max_text_length,
        max_elem_length,
        max_cell_num,
        character_dict_path,
        span_weight=1.0,
        **kwargs
    ):
        self.max_text_length = max_text_length
        self.max_elem_length = max_elem_length
        self.max_cell_num = max_cell_num
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        for i, char in enumerate(list_character):
            self.dict_character[char] = i
        self.dict_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_elem[elem] = i
        self.span_weight = span_weight

    def load_char_elem_dict(self, character_dict_path):
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            substr = lines[0].decode("utf-8").strip("\r\n").split("\t")
            character_num = int(substr[0])
            elem_num = int(substr[1])
            for cno in range(1, 1 + character_num):
                character = lines[cno].decode("utf-8").strip("\r\n")
                list_character.append(character)
            for eno in range(1 + character_num, 1 + character_num + elem_num):
                elem = lines[eno].decode("utf-8").strip("\r\n")
                list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        list_character = [self.beg_str] + list_character + [self.end_str]
        return list_character

    def get_span_idx_list(self):
        span_idx_list = []
        for elem in self.dict_elem:
            if "span" in elem:
                span_idx_list.append(self.dict_elem[elem])
        return span_idx_list

    def __call__(self, data):
        cells = data["cells"]
        structure = data["structure"]["tokens"]
        structure = self.encode(structure, "elem")
        if structure is None:
            return None
        elem_num = len(structure)
        structure = [0] + structure + [len(self.dict_elem) - 1]
        structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
        structure = np.array(structure)
        data["structure"] = structure
        elem_char_idx1 = self.dict_elem["<td>"]
        elem_char_idx2 = self.dict_elem["<td"]
        span_idx_list = self.get_span_idx_list()
        td_idx_list = np.logical_or(
            structure == elem_char_idx1, structure == elem_char_idx2
        )
        td_idx_list = np.where(td_idx_list)[0]

        structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
        bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
        bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
        img_height, img_width, img_ch = data["image"].shape
        if len(span_idx_list) > 0:
            span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
            span_weight = min(max(span_weight, 1.0), self.span_weight)
        for cno in range(len(cells)):
            if "bbox" in cells[cno]:
                bbox = cells[cno]["bbox"].copy()
                bbox[0] = bbox[0] * 1.0 / img_width
                bbox[1] = bbox[1] * 1.0 / img_height
                bbox[2] = bbox[2] * 1.0 / img_width
                bbox[3] = bbox[3] * 1.0 / img_height
                td_idx = td_idx_list[cno]
                bbox_list[td_idx] = bbox
                bbox_list_mask[td_idx] = 1.0
                cand_span_idx = td_idx + 1
                if cand_span_idx < (self.max_elem_length + 2):
                    if structure[cand_span_idx] in span_idx_list:
                        structure_mask[cand_span_idx] = span_weight

        data["bbox_list"] = bbox_list
        data["bbox_list_mask"] = bbox_list_mask
        data["structure_mask"] = structure_mask
        char_beg_idx = self.get_beg_end_flag_idx("beg", "char")
        char_end_idx = self.get_beg_end_flag_idx("end", "char")
        elem_beg_idx = self.get_beg_end_flag_idx("beg", "elem")
        elem_end_idx = self.get_beg_end_flag_idx("end", "elem")
        data["sp_tokens"] = np.array(
            [
                char_beg_idx,
                char_end_idx,
                elem_beg_idx,
                elem_end_idx,
                elem_char_idx1,
                elem_char_idx2,
                self.max_text_length,
                self.max_elem_length,
                self.max_cell_num,
                elem_num,
            ]
        )
        return data

    def encode(self, text, char_or_elem):
        """convert text-label into text-index."""
        if char_or_elem == "char":
            max_len = self.max_text_length
            current_dict = self.dict_character
        else:
            max_len = self.max_elem_length
            current_dict = self.dict_elem
        if len(text) > max_len:
            return None
        if len(text) == 0:
            if char_or_elem == "char":
                return [self.dict_character["space"]]
            else:
                return None
        text_list = []
        for char in text:
            if char not in current_dict:
                return None
            text_list.append(current_dict[char])
        if len(text_list) == 0:
            if char_or_elem == "char":
                return [self.dict_character["space"]]
            else:
                return None
        return text_list

    def get_ignored_tokens(self, char_or_elem):
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        if char_or_elem == "char":
            if beg_or_end == "beg":
                idx = np.array(self.dict_character[self.beg_str])
            elif beg_or_end == "end":
                idx = np.array(self.dict_character[self.end_str])
            else:
                assert False, (
                    "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
                )
        elif char_or_elem == "elem":
            if beg_or_end == "beg":
                idx = np.array(self.dict_elem[self.beg_str])
            elif beg_or_end == "end":
                idx = np.array(self.dict_elem[self.end_str])
            else:
                assert False, (
                    "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
                )
        else:
            assert False, "Unsupport type %s in char_or_elem" % char_or_elem
        return idx


class SARLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index"""

    def __init__(
        self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
    ):
        super(SARLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char
        )

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1

        return dict_character

    def __call__(self, data):
        text = data["label"]
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len - 1:
            return None
        data["length"] = np.array(len(text))
        target = [self.start_idx] + text + [self.end_idx]
        padded_text = [self.padding_idx for _ in range(self.max_text_len)]

        padded_text[: len(target)] = target
        data["label"] = np.array(padded_text)
        return data

    def get_ignored_tokens(self):
        return [self.padding_idx]

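# Sketch of the SAR label layout (hypothetical sizes, illustration only):
# add_special_char appends <UKN>, <BOS/EOS> and <PAD> in that order, so with a
# dictionary of 40 symbols in total, unknown_idx = 37, start_idx = end_idx = 38
# and padding_idx = 39. A text encoded as [7, 7, 2] with max_text_len = 8 becomes
#
#   target        = [38, 7, 7, 2, 38]
#   data["label"] = [38, 7, 7, 2, 38, 39, 39, 39]
#
# and get_ignored_tokens() reports only the padding index.
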
class PRENLabelEncode(BaseRecLabelEncode):
    def __init__(
        self, max_text_length, character_dict_path, use_space_char=False, **kwargs
    ):
        super(PRENLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char
        )

    def add_special_char(self, dict_character):
        padding_str = "<PAD>"  # 0
        end_str = "<EOS>"  # 1
        unknown_str = "<UNK>"  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def encode(self, text):
        if len(text) == 0 or len(text) >= self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
            if char not in self.dict:
                text_list.append(self.unknown_idx)
            else:
                text_list.append(self.dict[char])
        text_list.append(self.end_idx)
        if len(text_list) < self.max_text_len:
            text_list += [self.padding_idx] * (self.max_text_len - len(text_list))
        return text_list

    def __call__(self, data):
        text = data["label"]
        encoded_text = self.encode(text)
        if encoded_text is None:
            return None
        data["label"] = np.array(encoded_text)
        return data


class VQATokenLabelEncode(object):
    """
    Label encode for NLP VQA methods
    """

    def __init__(
        self,
        class_path,
        contains_re=False,
        add_special_ids=False,
        algorithm="LayoutXLM",
        infer_mode=False,
        ocr_engine=None,
        **kwargs
    ):
        super(VQATokenLabelEncode, self).__init__()
        from paddlenlp.transformers import (
            LayoutLMTokenizer,
            LayoutLMv2Tokenizer,
            LayoutXLMTokenizer,
        )

        from ppocr.utils.utility import load_vqa_bio_label_maps

        tokenizer_dict = {
            "LayoutXLM": {
                "class": LayoutXLMTokenizer,
                "pretrained_model": "layoutxlm-base-uncased",
            },
            "LayoutLM": {
                "class": LayoutLMTokenizer,
                "pretrained_model": "layoutlm-base-uncased",
            },
            "LayoutLMv2": {
                "class": LayoutLMv2Tokenizer,
                "pretrained_model": "layoutlmv2-base-uncased",
            },
        }
        self.contains_re = contains_re
        tokenizer_config = tokenizer_dict[algorithm]
        self.tokenizer = tokenizer_config["class"].from_pretrained(
            tokenizer_config["pretrained_model"]
        )
        self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
        self.add_special_ids = add_special_ids
        self.infer_mode = infer_mode
        self.ocr_engine = ocr_engine

    def __call__(self, data):
        # load bbox and label info
        ocr_info = self._load_ocr_info(data)

        height, width, _ = data["image"].shape

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        segment_offset_id = []
        gt_label_list = []

        entities = []

        # for re
        train_re = self.contains_re and not self.infer_mode
        if train_re:
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()

        data["ocr_info"] = copy.deepcopy(ocr_info)

        for info in ocr_info:
            if train_re:
                # for re
                if len(info["text"]) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                relations.extend([tuple(sorted(l)) for l in info["linking"]])
            # smooth_box
            bbox = self._smooth_box(info["bbox"], height, width)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True
            )

            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
            # parse label
            if not self.infer_mode:
                label = info["label"]
                gt_label = self._parse_label(label, encode_res)

            # construct entities for re
            if train_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    label = label.upper()
                    entities.append(
                        {
                            "start": len(input_ids_list),
                            "end": len(input_ids_list) + len(encode_res["input_ids"]),
                            "label": label.upper(),
                        }
                    )
            else:
                entities.append(
                    {
                        "start": len(input_ids_list),
                        "end": len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": "O",
                    }
                )
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
            words_list.append(text)
            segment_offset_id.append(len(input_ids_list))
            if not self.infer_mode:
                gt_label_list.extend(gt_label)

        data["input_ids"] = input_ids_list
        data["token_type_ids"] = token_type_ids_list
        data["bbox"] = bbox_list
        data["attention_mask"] = [1] * len(input_ids_list)
        data["labels"] = gt_label_list
        data["segment_offset_id"] = segment_offset_id
        data["tokenizer_params"] = dict(
            padding_side=self.tokenizer.padding_side,
            pad_token_type_id=self.tokenizer.pad_token_type_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        data["entities"] = entities

        if train_re:
            data["relations"] = relations
            data["id2label"] = id2label
            data["empty_entity"] = empty_entity
            data["entity_id_to_index_map"] = entity_id_to_index_map
        return data

    def _load_ocr_info(self, data):
        def trans_poly_to_bbox(poly):
            x1 = np.min([p[0] for p in poly])
            x2 = np.max([p[0] for p in poly])
            y1 = np.min([p[1] for p in poly])
            y2 = np.max([p[1] for p in poly])
            return [x1, y1, x2, y2]

        if self.infer_mode:
            ocr_result = self.ocr_engine.ocr(data["image"], cls=False)
            ocr_info = []
            for res in ocr_result:
                ocr_info.append(
                    {
                        "text": res[1][0],
         | 
| 986 | 
            +
                                    "bbox": trans_poly_to_bbox(res[0]),
         | 
| 987 | 
            +
                                    "poly": res[0],
         | 
| 988 | 
            +
                                }
         | 
| 989 | 
            +
                            )
         | 
| 990 | 
            +
                        return ocr_info
         | 
| 991 | 
            +
                    else:
         | 
| 992 | 
            +
                        info = data["label"]
         | 
| 993 | 
            +
                        # read text info
         | 
| 994 | 
            +
                        info_dict = json.loads(info)
         | 
| 995 | 
            +
                        return info_dict["ocr_info"]
         | 
| 996 | 
            +
             | 
| 997 | 
            +
                def _smooth_box(self, bbox, height, width):
         | 
| 998 | 
            +
                    bbox[0] = int(bbox[0] * 1000.0 / width)
         | 
| 999 | 
            +
                    bbox[2] = int(bbox[2] * 1000.0 / width)
         | 
| 1000 | 
            +
                    bbox[1] = int(bbox[1] * 1000.0 / height)
         | 
| 1001 | 
            +
                    bbox[3] = int(bbox[3] * 1000.0 / height)
         | 
| 1002 | 
            +
                    return bbox
         | 
| 1003 | 
            +
             | 
| 1004 | 
            +
                def _parse_label(self, label, encode_res):
         | 
| 1005 | 
            +
                    gt_label = []
         | 
| 1006 | 
            +
                    if label.lower() == "other":
         | 
| 1007 | 
            +
                        gt_label.extend([0] * len(encode_res["input_ids"]))
         | 
| 1008 | 
            +
                    else:
         | 
| 1009 | 
            +
                        gt_label.append(self.label2id_map[("b-" + label).upper()])
         | 
| 1010 | 
            +
                        gt_label.extend(
         | 
| 1011 | 
            +
                            [self.label2id_map[("i-" + label).upper()]]
         | 
| 1012 | 
            +
                            * (len(encode_res["input_ids"]) - 1)
         | 
| 1013 | 
            +
                        )
         | 
| 1014 | 
            +
                    return gt_label
         | 
| 1015 | 
            +
             | 
| 1016 | 
            +
             | 
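
Before `MultiLabelEncode`, a quick illustration of the two helpers above: `_smooth_box` rescales pixel boxes onto the fixed 0-1000 grid that LayoutLM-style models expect, and `_parse_label` expands a single entity label into BIO tags over its sub-tokens. A minimal, self-contained sketch (the toy `label2id_map` and the standalone function names are illustrative, not part of this PR):

```python
# Illustrative sketch only -- mirrors _smooth_box and _parse_label above.

def smooth_box(bbox, height, width):
    # Rescale pixel coordinates into the 0-1000 grid used by LayoutLM-style models.
    x1, y1, x2, y2 = bbox
    return [
        int(x1 * 1000.0 / width),
        int(y1 * 1000.0 / height),
        int(x2 * 1000.0 / width),
        int(y2 * 1000.0 / height),
    ]

# Toy label map; the real one comes from the dataset configuration.
label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2}

def parse_label(label, num_subtokens):
    # "other" maps every sub-token to O; any other class gets B- on the first
    # sub-token and I- on the rest (BIO tagging).
    if label.lower() == "other":
        return [label2id_map["O"]] * num_subtokens
    return [label2id_map[("b-" + label).upper()]] + \
           [label2id_map[("i-" + label).upper()]] * (num_subtokens - 1)

print(smooth_box([50, 20, 250, 60], height=800, width=1000))  # [50, 25, 250, 75]
print(parse_label("question", 3))                             # [1, 2, 2]
```
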
| 1017 | 
            +
            class MultiLabelEncode(BaseRecLabelEncode):
         | 
| 1018 | 
            +
                def __init__(
         | 
| 1019 | 
            +
                    self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
         | 
| 1020 | 
            +
                ):
         | 
| 1021 | 
            +
                    super(MultiLabelEncode, self).__init__(
         | 
| 1022 | 
            +
                        max_text_length, character_dict_path, use_space_char
         | 
| 1023 | 
            +
                    )
         | 
| 1024 | 
            +
             | 
| 1025 | 
            +
                    self.ctc_encode = CTCLabelEncode(
         | 
| 1026 | 
            +
                        max_text_length, character_dict_path, use_space_char, **kwargs
         | 
| 1027 | 
            +
                    )
         | 
| 1028 | 
            +
                    self.sar_encode = SARLabelEncode(
         | 
| 1029 | 
            +
                        max_text_length, character_dict_path, use_space_char, **kwargs
         | 
| 1030 | 
            +
                    )
         | 
| 1031 | 
            +
             | 
| 1032 | 
            +
                def __call__(self, data):
         | 
| 1033 | 
            +
             | 
| 1034 | 
            +
                    data_ctc = copy.deepcopy(data)
         | 
| 1035 | 
            +
                    data_sar = copy.deepcopy(data)
         | 
| 1036 | 
            +
                    data_out = dict()
         | 
| 1037 | 
            +
                    data_out["img_path"] = data.get("img_path", None)
         | 
| 1038 | 
            +
                    data_out["image"] = data["image"]
         | 
| 1039 | 
            +
                    ctc = self.ctc_encode.__call__(data_ctc)
         | 
| 1040 | 
            +
                    sar = self.sar_encode.__call__(data_sar)
         | 
| 1041 | 
            +
                    if ctc is None or sar is None:
         | 
| 1042 | 
            +
                        return None
         | 
| 1043 | 
            +
                    data_out["label_ctc"] = ctc["label"]
         | 
| 1044 | 
            +
                    data_out["label_sar"] = sar["label"]
         | 
| 1045 | 
            +
                    data_out["length"] = ctc["length"]
         | 
| 1046 | 
            +
                    return data_out
         | 
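
`MultiLabelEncode` above fans a single recognition sample out to two heads and keeps both label encodings. A minimal sketch of that pattern with stand-in encoders (the real ones are the `CTCLabelEncode` and `SARLabelEncode` classes referenced in its constructor); the key detail is the deep copy per branch, so one encoder's in-place edits cannot leak into the other's view:

```python
# Illustrative sketch only -- the branch encoders below are stand-ins.
import copy

def encode_branch_a(sample):
    sample["label"] = list(sample["label"])            # pretend CTC-style encoding
    return sample

def encode_branch_b(sample):
    sample["label"] = list(reversed(sample["label"]))  # pretend SAR-style encoding
    return sample

def multi_label_encode(sample):
    a = encode_branch_a(copy.deepcopy(sample))
    b = encode_branch_b(copy.deepcopy(sample))
    if a is None or b is None:
        return None  # drop the sample if either branch rejects it
    return {"image": sample["image"],
            "label_ctc": a["label"],
            "label_sar": b["label"]}

out = multi_label_encode({"image": "<ndarray>", "label": "abc"})
print(out["label_ctc"], out["label_sar"])  # ['a', 'b', 'c'] ['c', 'b', 'a']
```
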
    	
        ocr/ppocr/data/imaug/make_border_map.py
    ADDED
    
    | @@ -0,0 +1,155 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            np.seterr(divide="ignore", invalid="ignore")
         | 
| 7 | 
            +
            import warnings
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import pyclipper
         | 
| 10 | 
            +
            from shapely.geometry import Polygon
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            warnings.simplefilter("ignore")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            __all__ = ["MakeBorderMap"]
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class MakeBorderMap(object):
         | 
| 18 | 
            +
                def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs):
         | 
| 19 | 
            +
                    self.shrink_ratio = shrink_ratio
         | 
| 20 | 
            +
                    self.thresh_min = thresh_min
         | 
| 21 | 
            +
                    self.thresh_max = thresh_max
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __call__(self, data):
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    img = data["image"]
         | 
| 26 | 
            +
                    text_polys = data["polys"]
         | 
| 27 | 
            +
                    ignore_tags = data["ignore_tags"]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    canvas = np.zeros(img.shape[:2], dtype=np.float32)
         | 
| 30 | 
            +
                    mask = np.zeros(img.shape[:2], dtype=np.float32)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    for i in range(len(text_polys)):
         | 
| 33 | 
            +
                        if ignore_tags[i]:
         | 
| 34 | 
            +
                            continue
         | 
| 35 | 
            +
                        self.draw_border_map(text_polys[i], canvas, mask=mask)
         | 
| 36 | 
            +
                    canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    data["threshold_map"] = canvas
         | 
| 39 | 
            +
                    data["threshold_mask"] = mask
         | 
| 40 | 
            +
                    return data
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def draw_border_map(self, polygon, canvas, mask):
         | 
| 43 | 
            +
                    polygon = np.array(polygon)
         | 
| 44 | 
            +
                    assert polygon.ndim == 2
         | 
| 45 | 
            +
                    assert polygon.shape[1] == 2
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    polygon_shape = Polygon(polygon)
         | 
| 48 | 
            +
                    if polygon_shape.area <= 0:
         | 
| 49 | 
            +
                        return
         | 
| 50 | 
            +
                    distance = (
         | 
| 51 | 
            +
                        polygon_shape.area
         | 
| 52 | 
            +
                        * (1 - np.power(self.shrink_ratio, 2))
         | 
| 53 | 
            +
                        / polygon_shape.length
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    subject = [tuple(l) for l in polygon]
         | 
| 56 | 
            +
                    padding = pyclipper.PyclipperOffset()
         | 
| 57 | 
            +
                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    padded_polygon = np.array(padding.Execute(distance)[0])
         | 
| 60 | 
            +
                    cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    xmin = padded_polygon[:, 0].min()
         | 
| 63 | 
            +
                    xmax = padded_polygon[:, 0].max()
         | 
| 64 | 
            +
                    ymin = padded_polygon[:, 1].min()
         | 
| 65 | 
            +
                    ymax = padded_polygon[:, 1].max()
         | 
| 66 | 
            +
                    width = xmax - xmin + 1
         | 
| 67 | 
            +
                    height = ymax - ymin + 1
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    polygon[:, 0] = polygon[:, 0] - xmin
         | 
| 70 | 
            +
                    polygon[:, 1] = polygon[:, 1] - ymin
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    xs = np.broadcast_to(
         | 
| 73 | 
            +
                        np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
         | 
| 74 | 
            +
                    )
         | 
| 75 | 
            +
                    ys = np.broadcast_to(
         | 
| 76 | 
            +
                        np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
         | 
| 77 | 
            +
                    )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
         | 
| 80 | 
            +
                    for i in range(polygon.shape[0]):
         | 
| 81 | 
            +
                        j = (i + 1) % polygon.shape[0]
         | 
| 82 | 
            +
                        absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
         | 
| 83 | 
            +
                        distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
         | 
| 84 | 
            +
                    distance_map = distance_map.min(axis=0)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
         | 
| 87 | 
            +
                    xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
         | 
| 88 | 
            +
                    ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
         | 
| 89 | 
            +
                    ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
         | 
| 90 | 
            +
                    canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
         | 
| 91 | 
            +
                        1
         | 
| 92 | 
            +
                        - distance_map[
         | 
| 93 | 
            +
                            ymin_valid - ymin : ymax_valid - ymax + height,
         | 
| 94 | 
            +
                            xmin_valid - xmin : xmax_valid - xmax + width,
         | 
| 95 | 
            +
                        ],
         | 
| 96 | 
            +
                        canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def _distance(self, xs, ys, point_1, point_2):
         | 
| 100 | 
            +
                    """
         | 
| 101 | 
            +
                compute the distance from each point to a line segment
         | 
| 102 | 
            +
                    ys: coordinates in the first axis
         | 
| 103 | 
            +
                    xs: coordinates in the second axis
         | 
| 104 | 
            +
                point_1, point_2: (x, y), the endpoints of the line segment
         | 
| 105 | 
            +
                    """
         | 
| 106 | 
            +
                    height, width = xs.shape[:2]
         | 
| 107 | 
            +
                    square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
         | 
| 108 | 
            +
                    square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
         | 
| 109 | 
            +
                    square_distance = np.square(point_1[0] - point_2[0]) + np.square(
         | 
| 110 | 
            +
                        point_1[1] - point_2[1]
         | 
| 111 | 
            +
                    )
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    cosin = (square_distance - square_distance_1 - square_distance_2) / (
         | 
| 114 | 
            +
                        2 * np.sqrt(square_distance_1 * square_distance_2)
         | 
| 115 | 
            +
                    )
         | 
| 116 | 
            +
                    square_sin = 1 - np.square(cosin)
         | 
| 117 | 
            +
                    square_sin = np.nan_to_num(square_sin)
         | 
| 118 | 
            +
                    result = np.sqrt(
         | 
| 119 | 
            +
                        square_distance_1 * square_distance_2 * square_sin / square_distance
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
         | 
| 123 | 
            +
                        cosin < 0
         | 
| 124 | 
            +
                    ]
         | 
| 125 | 
            +
                    # self.extend_line(point_1, point_2, result)
         | 
| 126 | 
            +
                    return result
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def extend_line(self, point_1, point_2, result, shrink_ratio):
         | 
| 129 | 
            +
                    ex_point_1 = (
         | 
| 130 | 
            +
                        int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
         | 
| 131 | 
            +
                        int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))),
         | 
| 132 | 
            +
                    )
         | 
| 133 | 
            +
                    cv2.line(
         | 
| 134 | 
            +
                        result,
         | 
| 135 | 
            +
                        tuple(ex_point_1),
         | 
| 136 | 
            +
                        tuple(point_1),
         | 
| 137 | 
            +
                        4096.0,
         | 
| 138 | 
            +
                        1,
         | 
| 139 | 
            +
                        lineType=cv2.LINE_AA,
         | 
| 140 | 
            +
                        shift=0,
         | 
| 141 | 
            +
                    )
         | 
| 142 | 
            +
                    ex_point_2 = (
         | 
| 143 | 
            +
                        int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
         | 
| 144 | 
            +
                        int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))),
         | 
| 145 | 
            +
                    )
         | 
| 146 | 
            +
                    cv2.line(
         | 
| 147 | 
            +
                        result,
         | 
| 148 | 
            +
                        tuple(ex_point_2),
         | 
| 149 | 
            +
                        tuple(point_2),
         | 
| 150 | 
            +
                        4096.0,
         | 
| 151 | 
            +
                        1,
         | 
| 152 | 
            +
                        lineType=cv2.LINE_AA,
         | 
| 153 | 
            +
                        shift=0,
         | 
| 154 | 
            +
                    )
         | 
| 155 | 
            +
                    return ex_point_1, ex_point_2
         | 
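
A usage sketch for `MakeBorderMap` on a synthetic sample, assuming the package layout introduced by this PR (`ocr/ppocr/data/imaug/make_border_map.py`); the dict keys mirror what the detection pipeline provides upstream:

```python
import numpy as np

from ocr.ppocr.data.imaug.make_border_map import MakeBorderMap

data = {
    "image": np.zeros((100, 200, 3), dtype=np.uint8),
    "polys": [np.array([[20, 20], [120, 20], [120, 60], [20, 60]], dtype=np.float32)],
    "ignore_tags": [False],
}

op = MakeBorderMap(shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7)
out = op(data)

# threshold_map values lie in [thresh_min, thresh_max];
# threshold_mask marks the dilated band around each text polygon.
print(out["threshold_map"].shape)   # (100, 200)
print(out["threshold_mask"].sum())  # pixel count of the padded polygon
```
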
    	
        ocr/ppocr/data/imaug/make_pse_gt.py
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import pyclipper
         | 
| 6 | 
            +
            from shapely.geometry import Polygon
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            __all__ = ["MakePseGt"]
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class MakePseGt(object):
         | 
| 12 | 
            +
                def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
         | 
| 13 | 
            +
                    self.kernel_num = kernel_num
         | 
| 14 | 
            +
                    self.min_shrink_ratio = min_shrink_ratio
         | 
| 15 | 
            +
                    self.size = size
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __call__(self, data):
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    image = data["image"]
         | 
| 20 | 
            +
                    text_polys = data["polys"]
         | 
| 21 | 
            +
                    ignore_tags = data["ignore_tags"]
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    h, w, _ = image.shape
         | 
| 24 | 
            +
                    short_edge = min(h, w)
         | 
| 25 | 
            +
                    if short_edge < self.size:
         | 
| 26 | 
            +
                    # keep the short edge >= self.size
         | 
| 27 | 
            +
                        scale = self.size / short_edge
         | 
| 28 | 
            +
                        image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
         | 
| 29 | 
            +
                        text_polys *= scale
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    gt_kernels = []
         | 
| 32 | 
            +
                    for i in range(1, self.kernel_num + 1):
         | 
| 33 | 
            +
                        # s1->sn, from big to small
         | 
| 34 | 
            +
                        rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
         | 
| 35 | 
            +
                        text_kernel, ignore_tags = self.generate_kernel(
         | 
| 36 | 
            +
                            image.shape[0:2], rate, text_polys, ignore_tags
         | 
| 37 | 
            +
                        )
         | 
| 38 | 
            +
                        gt_kernels.append(text_kernel)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    training_mask = np.ones(image.shape[0:2], dtype="uint8")
         | 
| 41 | 
            +
                    for i in range(text_polys.shape[0]):
         | 
| 42 | 
            +
                        if ignore_tags[i]:
         | 
| 43 | 
            +
                            cv2.fillPoly(
         | 
| 44 | 
            +
                                training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0
         | 
| 45 | 
            +
                            )
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    gt_kernels = np.array(gt_kernels)
         | 
| 48 | 
            +
                    gt_kernels[gt_kernels > 0] = 1
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    data["image"] = image
         | 
| 51 | 
            +
                    data["polys"] = text_polys
         | 
| 52 | 
            +
                    data["gt_kernels"] = gt_kernels[0:]
         | 
| 53 | 
            +
                    data["gt_text"] = gt_kernels[0]
         | 
| 54 | 
            +
                    data["mask"] = training_mask.astype("float32")
         | 
| 55 | 
            +
                    return data
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
         | 
| 58 | 
            +
                    """
         | 
| 59 | 
            +
                    Refer to part of the code:
         | 
| 60 | 
            +
                    https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
         | 
| 61 | 
            +
                    """
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    h, w = img_size
         | 
| 64 | 
            +
                    text_kernel = np.zeros((h, w), dtype=np.float32)
         | 
| 65 | 
            +
                    for i, poly in enumerate(text_polys):
         | 
| 66 | 
            +
                        polygon = Polygon(poly)
         | 
| 67 | 
            +
                        distance = (
         | 
| 68 | 
            +
                            polygon.area
         | 
| 69 | 
            +
                            * (1 - shrink_ratio * shrink_ratio)
         | 
| 70 | 
            +
                            / (polygon.length + 1e-6)
         | 
| 71 | 
            +
                        )
         | 
| 72 | 
            +
                        subject = [tuple(l) for l in poly]
         | 
| 73 | 
            +
                        pco = pyclipper.PyclipperOffset()
         | 
| 74 | 
            +
                        pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         | 
| 75 | 
            +
                        shrinked = np.array(pco.Execute(-distance))
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        if len(shrinked) == 0 or shrinked.size == 0:
         | 
| 78 | 
            +
                            if ignore_tags is not None:
         | 
| 79 | 
            +
                                ignore_tags[i] = True
         | 
| 80 | 
            +
                            continue
         | 
| 81 | 
            +
                        try:
         | 
| 82 | 
            +
                            shrinked = np.array(shrinked[0]).reshape(-1, 2)
         | 
| 83 | 
            +
                    except Exception:
         | 
| 84 | 
            +
                            if ignore_tags is not None:
         | 
| 85 | 
            +
                                ignore_tags[i] = True
         | 
| 86 | 
            +
                            continue
         | 
| 87 | 
            +
                        cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
         | 
| 88 | 
            +
                    return text_kernel, ignore_tags
         | 
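
The core of `generate_kernel` above is a negative polygon offset: the shrink distance is derived from the polygon's area-to-perimeter ratio and passed to pyclipper with a minus sign. A standalone sketch of just that step (the rectangle and `shrink_ratio` are arbitrary example values):

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

poly = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.float32)
shrink_ratio = 0.5

polygon = Polygon(poly)
distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (polygon.length + 1e-6)

pco = pyclipper.PyclipperOffset()
pco.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = pco.Execute(-distance)  # negative distance shrinks the path inwards

print(round(distance, 2))                  # ~10.71 for this rectangle
print(np.array(shrunk[0]).reshape(-1, 2))  # the shrunken kernel polygon
```
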
    	
        ocr/ppocr/data/imaug/make_shrink_map.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import pyclipper
         | 
| 6 | 
            +
            from shapely.geometry import Polygon
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            __all__ = ["MakeShrinkMap"]
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class MakeShrinkMap(object):
         | 
| 12 | 
            +
                r"""
         | 
| 13 | 
            +
                Making binary mask from detection data with ICDAR format.
         | 
| 14 | 
            +
                Typically following the process of class `MakeICDARData`.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
         | 
| 18 | 
            +
                    self.min_text_size = min_text_size
         | 
| 19 | 
            +
                    self.shrink_ratio = shrink_ratio
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __call__(self, data):
         | 
| 22 | 
            +
                    image = data["image"]
         | 
| 23 | 
            +
                    text_polys = data["polys"]
         | 
| 24 | 
            +
                    ignore_tags = data["ignore_tags"]
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    h, w = image.shape[:2]
         | 
| 27 | 
            +
                    text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
         | 
| 28 | 
            +
                    gt = np.zeros((h, w), dtype=np.float32)
         | 
| 29 | 
            +
                    mask = np.ones((h, w), dtype=np.float32)
         | 
| 30 | 
            +
                    for i in range(len(text_polys)):
         | 
| 31 | 
            +
                        polygon = text_polys[i]
         | 
| 32 | 
            +
                        height = max(polygon[:, 1]) - min(polygon[:, 1])
         | 
| 33 | 
            +
                        width = max(polygon[:, 0]) - min(polygon[:, 0])
         | 
| 34 | 
            +
                        if ignore_tags[i] or min(height, width) < self.min_text_size:
         | 
| 35 | 
            +
                            cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
         | 
| 36 | 
            +
                            ignore_tags[i] = True
         | 
| 37 | 
            +
                        else:
         | 
| 38 | 
            +
                            polygon_shape = Polygon(polygon)
         | 
| 39 | 
            +
                            subject = [tuple(l) for l in polygon]
         | 
| 40 | 
            +
                            padding = pyclipper.PyclipperOffset()
         | 
| 41 | 
            +
                            padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         | 
| 42 | 
            +
                            shrinked = []
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                        # Increase the shrink ratio whenever the offset returns multiple polygons
         | 
| 45 | 
            +
                            possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
         | 
| 46 | 
            +
                            np.append(possible_ratios, 1)
         | 
| 47 | 
            +
                            # print(possible_ratios)
         | 
| 48 | 
            +
                            for ratio in possible_ratios:
         | 
| 49 | 
            +
                                # print(f"Change shrink ratio to {ratio}")
         | 
| 50 | 
            +
                                distance = (
         | 
| 51 | 
            +
                                    polygon_shape.area
         | 
| 52 | 
            +
                                    * (1 - np.power(ratio, 2))
         | 
| 53 | 
            +
                                    / polygon_shape.length
         | 
| 54 | 
            +
                                )
         | 
| 55 | 
            +
                                shrinked = padding.Execute(-distance)
         | 
| 56 | 
            +
                                if len(shrinked) == 1:
         | 
| 57 | 
            +
                                    break
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                            if shrinked == []:
         | 
| 60 | 
            +
                                cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
         | 
| 61 | 
            +
                                ignore_tags[i] = True
         | 
| 62 | 
            +
                                continue
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                        for each_shrink in shrinked:
         | 
| 65 | 
            +
                            shrink = np.array(each_shrink).reshape(-1, 2)
         | 
| 66 | 
            +
                            cv2.fillPoly(gt, [shrink.astype(np.int32)], 1)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    data["shrink_map"] = gt
         | 
| 69 | 
            +
                    data["shrink_mask"] = mask
         | 
| 70 | 
            +
                    return data
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def validate_polygons(self, polygons, ignore_tags, h, w):
         | 
| 73 | 
            +
                    """
         | 
| 74 | 
            +
                    polygons (numpy.array, required): of shape (num_instances, num_points, 2)
         | 
| 75 | 
            +
                    """
         | 
| 76 | 
            +
                    if len(polygons) == 0:
         | 
| 77 | 
            +
                        return polygons, ignore_tags
         | 
| 78 | 
            +
                    assert len(polygons) == len(ignore_tags)
         | 
| 79 | 
            +
                    for polygon in polygons:
         | 
| 80 | 
            +
                        polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
         | 
| 81 | 
            +
                        polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    for i in range(len(polygons)):
         | 
| 84 | 
            +
                        area = self.polygon_area(polygons[i])
         | 
| 85 | 
            +
                        if abs(area) < 1:
         | 
| 86 | 
            +
                            ignore_tags[i] = True
         | 
| 87 | 
            +
                        if area > 0:
         | 
| 88 | 
            +
                            polygons[i] = polygons[i][::-1, :]
         | 
| 89 | 
            +
                    return polygons, ignore_tags
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def polygon_area(self, polygon):
         | 
| 92 | 
            +
                    """
         | 
| 93 | 
            +
                    compute polygon area
         | 
| 94 | 
            +
                    """
         | 
| 95 | 
            +
                    area = 0
         | 
| 96 | 
            +
                    q = polygon[-1]
         | 
| 97 | 
            +
                    for p in polygon:
         | 
| 98 | 
            +
                        area += p[0] * q[1] - p[1] * q[0]
         | 
| 99 | 
            +
                        q = p
         | 
| 100 | 
            +
                    return area / 2.0
         | 
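
`validate_polygons` relies on the signed shoelace area computed by `polygon_area`: a near-zero magnitude marks a degenerate polygon (ignored), and the sign encodes winding order (positive-area polygons are reversed). A small sketch reproducing the function on toy inputs:

```python
import numpy as np

def polygon_area(polygon):
    # Signed shoelace area, same formula as MakeShrinkMap.polygon_area above.
    area = 0
    q = polygon[-1]
    for p in polygon:
        area += p[0] * q[1] - p[1] * q[0]
        q = p
    return area / 2.0

box = np.array([[0, 0], [4, 0], [4, 3], [0, 3]], dtype=np.float32)
print(polygon_area(box))        # -12.0 -> kept as-is
print(polygon_area(box[::-1]))  # +12.0 -> validate_polygons reverses the point order
print(polygon_area(np.array([[0, 0], [1, 0], [2, 0]])))  # 0.0 -> |area| < 1, ignored
```
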
    	
        ocr/ppocr/data/imaug/operators.py
    ADDED
    
    | @@ -0,0 +1,458 @@ | |
| 1 | 
            +
            from __future__ import absolute_import, division, print_function, unicode_literals
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import sys
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import cv2
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import six
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class DecodeImage(object):
         | 
| 12 | 
            +
                """decode image"""
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __init__(
         | 
| 15 | 
            +
                    self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs
         | 
| 16 | 
            +
                ):
         | 
| 17 | 
            +
                    self.img_mode = img_mode
         | 
| 18 | 
            +
                    self.channel_first = channel_first
         | 
| 19 | 
            +
                    self.ignore_orientation = ignore_orientation
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __call__(self, data):
         | 
| 22 | 
            +
                    img = data["image"]
         | 
| 23 | 
            +
                    if six.PY2:
         | 
| 24 | 
            +
                        assert (
         | 
| 25 | 
            +
                            type(img) is str and len(img) > 0
         | 
| 26 | 
            +
                        ), "invalid input 'img' in DecodeImage"
         | 
| 27 | 
            +
                    else:
         | 
| 28 | 
            +
                        assert (
         | 
| 29 | 
            +
                            type(img) is bytes and len(img) > 0
         | 
| 30 | 
            +
                        ), "invalid input 'img' in DecodeImage"
         | 
| 31 | 
            +
                    img = np.frombuffer(img, dtype="uint8")
         | 
| 32 | 
            +
                    if self.ignore_orientation:
         | 
| 33 | 
            +
                        img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
         | 
| 34 | 
            +
                    else:
         | 
| 35 | 
            +
                        img = cv2.imdecode(img, 1)
         | 
| 36 | 
            +
                    if img is None:
         | 
| 37 | 
            +
                        return None
         | 
| 38 | 
            +
                    if self.img_mode == "GRAY":
         | 
| 39 | 
            +
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         | 
| 40 | 
            +
                    elif self.img_mode == "RGB":
         | 
| 41 | 
            +
                        assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
         | 
| 42 | 
            +
                        img = img[:, :, ::-1]
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    if self.channel_first:
         | 
| 45 | 
            +
                        img = img.transpose((2, 0, 1))
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    data["image"] = img
         | 
| 48 | 
            +
                    return data
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
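
A sketch of what `DecodeImage` does to a raw byte string, using an in-memory PNG so the snippet stays self-contained (the dummy image is illustrative only): decode to an HWC BGR array with `cv2.imdecode`, then flip the channel order when `img_mode == "RGB"`:

```python
import cv2
import numpy as np

# Encode a dummy 8x8 image to PNG bytes, as if the file had been read in binary mode.
dummy = np.full((8, 8, 3), (255, 0, 0), dtype=np.uint8)   # pure blue in BGR order
img_bytes = cv2.imencode(".png", dummy)[1].tobytes()

buf = np.frombuffer(img_bytes, dtype="uint8")
img = cv2.imdecode(buf, 1)   # decode to an HWC array in BGR order
img = img[:, :, ::-1]        # img_mode == "RGB": flip channels to RGB
print(img.shape, img[0, 0])  # (8, 8, 3) [  0   0 255]
```
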
| 51 | 
            +
            class NRTRDecodeImage(object):
         | 
| 52 | 
            +
                """decode image"""
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def __init__(self, img_mode="RGB", channel_first=False, **kwargs):
         | 
| 55 | 
            +
                    self.img_mode = img_mode
         | 
| 56 | 
            +
                    self.channel_first = channel_first
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __call__(self, data):
         | 
| 59 | 
            +
                    img = data["image"]
         | 
| 60 | 
            +
                    if six.PY2:
         | 
| 61 | 
            +
                        assert (
         | 
| 62 | 
            +
                            type(img) is str and len(img) > 0
         | 
| 63 | 
            +
                        ), "invalid input 'img' in DecodeImage"
         | 
| 64 | 
            +
                    else:
         | 
| 65 | 
            +
                        assert (
         | 
| 66 | 
            +
                            type(img) is bytes and len(img) > 0
         | 
| 67 | 
            +
                        ), "invalid input 'img' in DecodeImage"
         | 
| 68 | 
            +
                    img = np.frombuffer(img, dtype="uint8")
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    img = cv2.imdecode(img, 1)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    if img is None:
         | 
| 73 | 
            +
                        return None
         | 
| 74 | 
            +
                    if self.img_mode == "GRAY":
         | 
| 75 | 
            +
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         | 
| 76 | 
            +
                    elif self.img_mode == "RGB":
         | 
| 77 | 
            +
                        assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
         | 
| 78 | 
            +
                        img = img[:, :, ::-1]
         | 
| 79 | 
            +
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         | 
| 80 | 
            +
                    if self.channel_first:
         | 
| 81 | 
            +
                        img = img.transpose((2, 0, 1))
         | 
| 82 | 
            +
                    data["image"] = img
         | 
| 83 | 
            +
                    return data
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            class NormalizeImage(object):
         | 
| 87 | 
            +
                """normalize image such as substract mean, divide std"""
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
         | 
| 90 | 
            +
                    if isinstance(scale, str):
         | 
| 91 | 
            +
                        scale = eval(scale)
         | 
| 92 | 
            +
                    self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
         | 
| 93 | 
            +
                    mean = mean if mean is not None else [0.485, 0.456, 0.406]
         | 
| 94 | 
            +
                    std = std if std is not None else [0.229, 0.224, 0.225]
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
         | 
| 97 | 
            +
                    self.mean = np.array(mean).reshape(shape).astype("float32")
         | 
| 98 | 
            +
                    self.std = np.array(std).reshape(shape).astype("float32")
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def __call__(self, data):
         | 
| 101 | 
            +
                    img = data["image"]
         | 
| 102 | 
            +
                    from PIL import Image
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    if isinstance(img, Image.Image):
         | 
| 105 | 
            +
                        img = np.array(img)
         | 
| 106 | 
            +
                    assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
         | 
| 107 | 
            +
                    data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
         | 
| 108 | 
            +
                    return data
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
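
The arithmetic `NormalizeImage` applies, sketched with `scale=1/255`, the default ImageNet mean/std, and `order='hwc'` (the typical configuration when it runs before `ToCHWImage`):

```python
import numpy as np

img = np.full((2, 2, 3), 128, dtype=np.uint8)  # mid-gray HWC image
scale = np.float32(1.0 / 255.0)
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)).astype("float32")
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)).astype("float32")

out = (img.astype("float32") * scale - mean) / std
print(out[0, 0])  # roughly [0.074 0.205 0.427]
```
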
| 111 | 
            +
            class ToCHWImage(object):
         | 
| 112 | 
            +
                """convert hwc image to chw image"""
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def __init__(self, **kwargs):
         | 
| 115 | 
            +
                    pass
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def __call__(self, data):
         | 
| 118 | 
            +
                    img = data["image"]
         | 
| 119 | 
            +
                    from PIL import Image
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    if isinstance(img, Image.Image):
         | 
| 122 | 
            +
                        img = np.array(img)
         | 
| 123 | 
            +
                    data["image"] = img.transpose((2, 0, 1))
         | 
| 124 | 
            +
                    return data
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            class Fasttext(object):
         | 
| 128 | 
            +
                def __init__(self, path="None", **kwargs):
         | 
| 129 | 
            +
                    import fasttext
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    self.fast_model = fasttext.load_model(path)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def __call__(self, data):
         | 
| 134 | 
            +
                    label = data["label"]
         | 
| 135 | 
            +
                    fast_label = self.fast_model[label]
         | 
| 136 | 
            +
                    data["fast_label"] = fast_label
         | 
| 137 | 
            +
                    return data
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            class KeepKeys(object):
         | 
| 141 | 
            +
                def __init__(self, keep_keys, **kwargs):
         | 
| 142 | 
            +
                    self.keep_keys = keep_keys
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def __call__(self, data):
         | 
| 145 | 
            +
                    data_list = []
         | 
| 146 | 
            +
                    for key in self.keep_keys:
         | 
| 147 | 
            +
                        data_list.append(data[key])
         | 
| 148 | 
            +
                    return data_list
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            class Pad(object):
         | 
| 152 | 
            +
                def __init__(self, size=None, size_div=32, **kwargs):
         | 
| 153 | 
            +
                    if size is not None and not isinstance(size, (int, list, tuple)):
         | 
| 154 | 
            +
                        raise TypeError(
         | 
| 155 | 
            +
                            "Type of target_size is invalid. Now is {}".format(type(size))
         | 
| 156 | 
            +
                        )
         | 
| 157 | 
            +
                    if isinstance(size, int):
         | 
| 158 | 
            +
                        size = [size, size]
         | 
| 159 | 
            +
                    self.size = size
         | 
| 160 | 
            +
                    self.size_div = size_div
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def __call__(self, data):
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    img = data["image"]
         | 
| 165 | 
            +
                    img_h, img_w = img.shape[0], img.shape[1]
         | 
| 166 | 
            +
                    if self.size:
         | 
| 167 | 
            +
                        resize_h2, resize_w2 = self.size
         | 
| 168 | 
            +
                        assert (
         | 
| 169 | 
            +
                            img_h < resize_h2 and img_w < resize_w2
         | 
| 170 | 
            +
                        ), "(h, w) of target size should be greater than (img_h, img_w)"
         | 
| 171 | 
            +
                    else:
         | 
| 172 | 
            +
                        resize_h2 = max(
         | 
| 173 | 
            +
                            int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
         | 
| 174 | 
            +
                            self.size_div,
         | 
| 175 | 
            +
                        )
         | 
| 176 | 
            +
                        resize_w2 = max(
         | 
| 177 | 
            +
                            int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
         | 
| 178 | 
            +
                            self.size_div,
         | 
| 179 | 
            +
                        )
         | 
| 180 | 
            +
                    img = cv2.copyMakeBorder(
         | 
| 181 | 
            +
                        img,
         | 
| 182 | 
            +
                        0,
         | 
| 183 | 
            +
                        resize_h2 - img_h,
         | 
| 184 | 
            +
                        0,
         | 
| 185 | 
            +
                        resize_w2 - img_w,
         | 
| 186 | 
            +
                        cv2.BORDER_CONSTANT,
         | 
| 187 | 
            +
                        value=0,
         | 
| 188 | 
            +
                    )
         | 
| 189 | 
            +
                    data["image"] = img
         | 
| 190 | 
            +
                    return data
         | 
| 191 | 
            +
             | 
| 192 | 
            +
             | 
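
`Pad` without an explicit `size` rounds each side up to the next multiple of `size_div` (32 by default) before zero-padding, which keeps the padded image compatible with stride-32 backbones. A minimal sketch of that computation (`padded_shape` is an illustrative helper, not part of the PR):

```python
import math

def padded_shape(img_h, img_w, size_div=32):
    # Same rounding as Pad.__call__ when self.size is None.
    resize_h = max(int(math.ceil(img_h / size_div) * size_div), size_div)
    resize_w = max(int(math.ceil(img_w / size_div) * size_div), size_div)
    return resize_h, resize_w

print(padded_shape(730, 1281))  # (736, 1312)
print(padded_shape(5, 5))       # (32, 32)
```
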
| 193 | 
            +
class Resize(object):
    def __init__(self, size=(640, 640), **kwargs):
        self.size = size

    def resize_image(self, img):
        resize_h, resize_w = self.size
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def __call__(self, data):
        img = data["image"]
        if "polys" in data:
            text_polys = data["polys"]

        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
        if "polys" in data:
            new_boxes = []
            for box in text_polys:
                new_box = []
                for cord in box:
                    new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
                new_boxes.append(new_box)
            data["polys"] = np.array(new_boxes, dtype=np.float32)
        data["image"] = img_resize
        return data

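# Editor's note: illustrative sketch only (not part of the original diff),
# assuming cv2/numpy are available as module-level imports; the function name
# `_demo_resize_operator` is hypothetical.
def _demo_resize_operator():
    import numpy as np

    op = Resize(size=(640, 640))
    data = {
        "image": np.zeros((480, 800, 3), dtype=np.uint8),
        "polys": np.array([[[0, 0], [100, 0], [100, 50], [0, 50]]], dtype=np.float32),
    }
    out = op(data)
    # The image is forced to 640x640; polygons are rescaled by the same
    # ratios (ratio_w = 640/800 = 0.8, ratio_h = 640/480 ~= 1.333).
    assert out["image"].shape[:2] == (640, 640)
    assert np.isclose(out["polys"][0][1][0], 80.0)
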
class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        if "image_shape" in kwargs:
            self.image_shape = kwargs["image_shape"]
            self.resize_type = 1
        elif "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        elif "resize_long" in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get("resize_long", 960)
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

    def __call__(self, data):
        img = data["image"]
        src_h, src_w, _ = img.shape

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data["image"] = img
        data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type1(self, img):
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        Resize the image so both sides are multiples of 32, as required by the network.
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception("unsupported limit type: {}".format(self.limit_type))
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # round both sides to multiples of 32, never below 32
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            # log the offending shape and abort if the resize fails
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        h, w, _ = img.shape

        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]

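# Editor's note: illustrative sketch only (not part of the original diff),
# assuming cv2/numpy are available as module-level imports; the function name
# `_demo_det_resize_for_test` is hypothetical.
def _demo_det_resize_for_test():
    import numpy as np

    # With limit_type="max" the longer side is capped at limit_side_len, then
    # both sides are rounded to the nearest multiple of 32.
    op = DetResizeForTest(limit_side_len=960, limit_type="max")
    data = {"image": np.zeros((1080, 1920, 3), dtype=np.uint8)}
    out = op(data)
    # 1080x1920 -> scaled by 960/1920 = 0.5 -> 540x960 -> rounded to 544x960.
    assert out["image"].shape[:2] == (544, 960)
    src_h, src_w, ratio_h, ratio_w = out["shape"]
    assert (src_h, src_w) == (1080, 1920)
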
class E2EResizeForTest(object):
    def __init__(self, **kwargs):
        super(E2EResizeForTest, self).__init__()
        self.max_side_len = kwargs["max_side_len"]
        self.valid_set = kwargs["valid_set"]

    def __call__(self, data):
        img = data["image"]
        src_h, src_w, _ = img.shape
        if self.valid_set == "totaltext":
            im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
                img, max_side_len=self.max_side_len
            )
        else:
            im_resized, (ratio_h, ratio_w) = self.resize_image(
                img, max_side_len=self.max_side_len
            )
        data["image"] = im_resized
        data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_for_totaltext(self, im, max_side_len=512):
        h, w, _ = im.shape
        resize_w = w
        resize_h = h
        ratio = 1.25
        if h * ratio > max_side_len:
            ratio = float(max_side_len) / resize_h
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return im, (ratio_h, ratio_w)

    def resize_image(self, im, max_side_len=512):
        """
        Resize the image so both sides are multiples of max_stride, as required by the network.
        :param im: the input image
        :param max_side_len: limit on the longer side to avoid GPU out-of-memory
        :return: the resized image and the resize ratios
        """
        h, w, _ = im.shape

        resize_w = w
        resize_h = h

        # Fix the longer side
        if resize_h > resize_w:
            ratio = float(max_side_len) / resize_h
        else:
            ratio = float(max_side_len) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return im, (ratio_h, ratio_w)

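# Editor's note: illustrative sketch only (not part of the original diff),
# assuming cv2/numpy are available as module-level imports. The valid_set
# value "partvgg" is just an arbitrary non-"totaltext" example, and the
# function name `_demo_e2e_resize_for_test` is hypothetical.
def _demo_e2e_resize_for_test():
    import numpy as np

    op = E2EResizeForTest(max_side_len=768, valid_set="partvgg")
    data = {"image": np.zeros((600, 800, 3), dtype=np.uint8)}
    out = op(data)
    # The longer side is capped at 768 (ratio 0.96 -> 576x768), then both
    # sides are rounded up to multiples of 128 -> 640x768.
    assert out["image"].shape[:2] == (640, 768)
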
class KieResize(object):
    def __init__(self, **kwargs):
        super(KieResize, self).__init__()
        self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1]

    def __call__(self, data):
        img = data["image"]
        points = data["points"]
        src_h, src_w, _ = img.shape
        (
            im_resized,
            scale_factor,
            [ratio_h, ratio_w],
            [new_h, new_w],
        ) = self.resize_image(img)
        resize_points = self.resize_boxes(img, points, scale_factor)
        data["ori_image"] = img
        data["ori_boxes"] = points
        data["points"] = resize_points
        data["image"] = im_resized
        data["shape"] = np.array([new_h, new_w])
        return data

    def resize_image(self, img):
        norm_img = np.zeros([1024, 1024, 3], dtype="float32")
        scale = [512, 1024]
        h, w = img.shape[:2]
        max_long_edge = max(scale)
        max_short_edge = min(scale)
        scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
        resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(
            h * float(scale_factor) + 0.5
        )
        max_stride = 32
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(img, (resize_w, resize_h))
        new_h, new_w = im.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        norm_img[:new_h, :new_w, :] = im
        return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]

    def resize_boxes(self, im, points, scale_factor):
        points = points * scale_factor
        img_shape = im.shape[:2]
        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
        return points
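
# Editor's note: illustrative sketch only (not part of the original diff),
# assuming cv2/numpy are available as module-level imports; the function name
# `_demo_kie_resize` is hypothetical.
def _demo_kie_resize():
    import numpy as np

    op = KieResize(img_scale=[1024, 512])
    data = {
        "image": np.zeros((400, 600, 3), dtype=np.uint8),
        "points": np.array([[0, 0, 100, 100]], dtype=np.float32),
    }
    out = op(data)
    # The image is rescaled by min(1024/600, 512/400) = 1.28 to 512x768, then
    # placed into a fixed 1024x1024 canvas; boxes are scaled by the same factor.
    assert out["image"].shape[:2] == (1024, 1024)
    assert list(out["shape"]) == [512, 768]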