Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	
		admin
		
	committed on
		
		
					Commit 
							
							·
						
						67a9b5d
	
1
								Parent(s):
							
							89fbbde
								
sync
Browse files- .gitattributes +2 -0
- .gitignore +7 -0
- README.md +8 -7
- app.py +86 -0
- insectid/__init__.py +2 -0
- insectid/base.py +51 -0
- insectid/detector.py +58 -0
- insectid/identifier.py +76 -0
- khandy/__init__.py +18 -0
- khandy/boxes/__init__.py +13 -0
- khandy/boxes/boxes_and_indices.py +68 -0
- khandy/boxes/boxes_clip.py +34 -0
- khandy/boxes/boxes_coder.py +69 -0
- khandy/boxes/boxes_convert.py +101 -0
- khandy/boxes/boxes_filter.py +113 -0
- khandy/boxes/boxes_overlap.py +166 -0
- khandy/boxes/boxes_transform_flip.py +135 -0
- khandy/boxes/boxes_transform_rotate.py +140 -0
- khandy/boxes/boxes_transform_scale.py +86 -0
- khandy/boxes/boxes_transform_translate.py +136 -0
- khandy/boxes/boxes_utils.py +28 -0
- khandy/dict_utils.py +168 -0
- khandy/draw_utils.py +148 -0
- khandy/feature_utils.py +62 -0
- khandy/file_io_utils.py +87 -0
- khandy/fs_utils.py +375 -0
- khandy/hash_utils.py +25 -0
- khandy/image/__init__.py +10 -0
- khandy/image/align_and_crop.py +60 -0
- khandy/image/crop_or_pad.py +138 -0
- khandy/image/flip.py +72 -0
- khandy/image/image_hash.py +69 -0
- khandy/image/misc.py +329 -0
- khandy/image/resize.py +177 -0
- khandy/image/rotate.py +72 -0
- khandy/image/translate.py +57 -0
- khandy/label/__init__.py +2 -0
- khandy/label/detect.py +594 -0
- khandy/list_utils.py +68 -0
- khandy/misc.py +245 -0
- khandy/numpy_utils.py +173 -0
- khandy/points/__init__.py +2 -0
- khandy/points/pts_letterbox.py +19 -0
- khandy/points/pts_transform_scale.py +33 -0
- khandy/split_utils.py +71 -0
- khandy/text_utils.py +33 -0
- khandy/time_utils.py +101 -0
- khandy/version.py +3 -0
- requirements.txt +7 -0
    	
        .gitattributes
    CHANGED
    
    | @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
|  | |
|  | 
|  | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
| 36 | 
            +
            images/Coccinella_septempunctata.jpg filter=lfs diff=lfs merge=lfs -text
         | 
| 37 | 
            +
            simsun.ttc filter=lfs diff=lfs merge=lfs -text
         | 
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            __pycache__/
         | 
| 2 | 
            +
            _local/
         | 
| 3 | 
            +
            *.pyc
         | 
| 4 | 
            +
            local_models_*/
         | 
| 5 | 
            +
            rename.sh
         | 
| 6 | 
            +
            *.onnx
         | 
| 7 | 
            +
            simsun.ttc
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,13 +1,14 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
            -
            title:  | 
| 3 | 
            -
            emoji:  | 
| 4 | 
            -
            colorFrom:  | 
| 5 | 
            -
            colorTo:  | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version:  | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
            -
            license: mit
         | 
| 11 | 
             
            ---
         | 
| 12 |  | 
| 13 | 
            -
             | 
|  | |
|  | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
            +
            title: insecta
         | 
| 3 | 
            +
            emoji: 🐞
         | 
| 4 | 
            +
            colorFrom: indigo
         | 
| 5 | 
            +
            colorTo: pink
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 4.39.0
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
|  | |
| 10 | 
             
            ---
         | 
| 11 |  | 
| 12 | 
            +
            # 特性
         | 
| 13 | 
            +
            - 支持 2037 类 (可能是目, 科, 属或种等) 昆虫或其他节肢动物
         | 
| 14 | 
            +
            - 模型开源, 持续更新.
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import khandy
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import gradio as gr
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            from modelscope import snapshot_download
         | 
| 7 | 
            +
            from insectid import InsectDetector, InsectIdentifier
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Download the "MuGeminorum/insecta" snapshot through modelscope and remember
            # the value it returns; below it is used as a directory prefix for the font
            # file and the example images. cache_dir is given relative to the current
            # working directory.
            # NOTE(review): the insectid model classes look for their files under
            # "__pycache__/MuGeminorum/insecta" next to their own module, so this
            # cache_dir presumably has to resolve to that same folder -- confirm.
            MODEL_DIR = snapshot_download("MuGeminorum/insecta", cache_dir="./insectid/__pycache__")
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def infer(filename: str):
                """Detect insects in an image, label every detection and report a match.

                Args:
                    filename: Path of the uploaded picture. May be empty or ``None``
                        when nothing was uploaded.

                Returns:
                    A ``(image, text)`` pair, one value per output component of the
                    interface. ``image`` is a PIL RGB image with boxes and labels drawn
                    on it, or ``None`` when there is nothing to show. ``text`` is the
                    label of the most recent confident identification with the
                    probability part (everything after ``":"``) removed, ``"未知"``
                    when no detection was confident, or a message describing why no
                    image could be produced.
                """
                if not filename:
                    # The original built this tuple but never returned it.
                    return None, "请上传图片 Please upload a picture"

                # NOTE: both models are constructed on every call.
                detector = InsectDetector()
                identifier = InsectIdentifier()
                image = khandy.imread(filename)
                if image is None:
                    # Keep the two-value shape expected by the two output components.
                    return None, "图片读取失败 Failed to read the picture"

                if max(image.shape[:2]) > 1280:
                    image = khandy.resize_image_long(image, 1280)

                image_for_draw = image.copy()
                image_height, image_width = image.shape[:2]
                boxes, _confs, _classes = detector.detect(image)
                text = "未知"
                for box in boxes:
                    box = box.astype(np.int32)
                    box_width = box[2] - box[0] + 1
                    box_height = box[3] - box[1] + 1
                    # Ignore detections that are too small to identify.
                    if box_width < 30 or box_height < 30:
                        continue

                    cropped = khandy.crop_or_pad(image, box[0], box[1], box[2], box[3])
                    results = identifier.identify(cropped)
                    best = results[0]
                    print(best)
                    prob = best["probability"]

                    # Each box gets its own label so that a low-confidence detection
                    # is not drawn with the name found for a previous box.
                    label = "未知"
                    if prob >= 0.10:
                        label = "{} {}: {:.2f}%".format(
                            best["chinese_name"],
                            best["latin_name"],
                            100.0 * prob,
                        )
                        text = label

                    # Put the caption just above the box, clamped to the image area.
                    position = [box[0] + 2, box[1] - 20]
                    position[0] = min(max(position[0], 0), image_width)
                    position[1] = min(max(position[1], 0), image_height)
                    cv2.rectangle(
                        image_for_draw,
                        (box[0], box[1]),
                        (box[2], box[3]),
                        (0, 255, 0),
                        2,
                    )
                    image_for_draw = khandy.draw_text(
                        image_for_draw,
                        label,
                        position,
                        font=f"{MODEL_DIR}/simsun.ttc",
                        font_size=15,
                    )

                outxt = text.split(":")[0] if ":" in text else text
                # The channel axis is reversed before the array is handed to PIL as RGB.
                return Image.fromarray(image_for_draw[:, :, ::-1], mode="RGB"), outxt
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            if __name__ == "__main__":
                # Assemble the single-input / two-output Gradio UI around `infer`
                # and start serving it.
                upload_widget = gr.Image(label="上传昆虫照片 Upload insect picture", type="filepath")
                result_widgets = [
                    gr.Image(label="识别结果 Recognition result"),
                    gr.Textbox(label="最可能的物种 Best match", show_copy_button=True),
                ]
                example_images = [
                    f"{MODEL_DIR}/examples/butterfly.jpg",
                    f"{MODEL_DIR}/examples/beetle.jpg",
                ]
                iface = gr.Interface(
                    fn=infer,
                    inputs=upload_widget,
                    outputs=result_widgets,
                    title="图像文件格式支持 PNG, JPG, JPEG 和 BMP, 且文件大小不超过 10M<br>Image file format support PNG, JPG, JPEG and BMP, and the file size does not exceed 10M.",
                    examples=example_images,
                    allow_flagging="never",
                    cache_examples=False,
                )

                iface.launch()
         | 
    	
        insectid/__init__.py
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .detector import *
         | 
| 2 | 
            +
            from .identifier import *
         | 
    	
        insectid/base.py
    ADDED
    
    | @@ -0,0 +1,51 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import onnxruntime
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class OnnxModel(object):
                """Small wrapper around an ``onnxruntime.InferenceSession``.

                The session is created with the CUDA provider (falling back to CPU)
                when ``onnxruntime.get_device()`` reports ``'GPU'``, and with the CPU
                provider only otherwise. The model's input and output names are read
                once at construction time.

                Args:
                    model_path: Model location handed to ``onnxruntime.InferenceSession``.
                """

                def __init__(self, model_path):
                    sess_options = onnxruntime.SessionOptions()
                    onnx_gpu = onnxruntime.get_device() == 'GPU'
                    if onnx_gpu:
                        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
                    else:
                        providers = ['CPUExecutionProvider']
                    self.sess = onnxruntime.InferenceSession(model_path, sess_options, providers=providers)
                    self._input_names = [item.name for item in self.sess.get_inputs()]
                    self._output_names = [item.name for item in self.sess.get_outputs()]

                @property
                def input_names(self):
                    """Names of the model inputs, in the order the session reports them."""
                    return self._input_names

                @property
                def output_names(self):
                    """Names of the model outputs, in the order the session reports them."""
                    return self._output_names

                def forward(self, inputs):
                    """Run the model on ``inputs``.

                    Args:
                        inputs: Either one input value, or a tuple/list of input values
                            paired positionally with ``input_names``.

                    Returns:
                        The single output value when one bare (non tuple/list) input
                        was given and the model has exactly one output; otherwise the
                        list of outputs returned by the session.
                    """
                    single_input = not isinstance(inputs, (tuple, list))
                    if single_input:
                        inputs = [inputs]
                    # Pair names and values positionally (avoids shadowing ``input``).
                    input_feed = dict(zip(self.input_names, inputs))
                    outputs = self.sess.run(self.output_names, input_feed)
                    if single_input and len(self.output_names) == 1:
                        return outputs[0]
                    return outputs
         | 
| 37 | 
            +
                        
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            def check_image_dtype_and_shape(image):
                """Validate that ``image`` is an ndarray with a supported dtype and shape.

                Accepted images are ``np.ndarray`` objects of dtype uint8 or uint16
                with 2 dimensions, or 3 dimensions whose last axis has 1, 3 or 4
                entries.

                Args:
                    image: Object to validate.

                Returns:
                    None. The function only raises on invalid input.

                Raises:
                    TypeError: If ``image`` is not an ``np.ndarray`` or its dtype is
                        neither uint8 nor uint16.
                    ValueError: If the number of dimensions or channels is unsupported.
                        Both are ``Exception`` subclasses, so handlers written for the
                        previous plain ``Exception`` keep working.
                """
                if not isinstance(image, np.ndarray):
                    raise TypeError('image is not np.ndarray!')

                # ``image.dtype`` is an ``np.dtype`` instance, never an instance of the
                # scalar classes, so the former ``isinstance`` test could not fire (and
                # was inverted). Compare the dtype by equality instead.
                if image.dtype not in (np.uint8, np.uint16):
                    raise TypeError(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
                if image.ndim not in {2, 3}:
                    raise ValueError(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
                if image.ndim == 3:
                    num_channels = image.shape[-1]
                    if num_channels not in {1, 3, 4}:
                        raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
         | 
| 51 | 
            +
             | 
    	
        insectid/detector.py
    ADDED
    
    | @@ -0,0 +1,58 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import khandy
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from .base import OnnxModel
         | 
| 5 | 
            +
            from .base import check_image_dtype_and_shape
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class InsectDetector(OnnxModel):
                """ONNX-based detector that returns boxes, confidences and class ids.

                Args:
                    model_path: Optional path of the detector model. When omitted, the
                        file ``__pycache__/MuGeminorum/insecta/quarrying_insect_detector.onnx``
                        located next to this module is used, as before.
                """

                def __init__(self, model_path=None):
                    if model_path is None:
                        current_dir = os.path.dirname(os.path.abspath(__file__))
                        model_path = os.path.join(
                            current_dir,
                            "__pycache__/MuGeminorum/insecta/quarrying_insect_detector.onnx",
                        )
                    # Size the image is letterboxed to before inference.
                    self.input_width = 640
                    self.input_height = 640
                    super().__init__(model_path)

                def _preprocess(self, image):
                    """Turn an image into the network input.

                    Returns:
                        ``(tensor, scale, pad_left, pad_top)``: the tensor plus the
                        letterbox parameters needed to map boxes back afterwards.
                    """
                    check_image_dtype_and_shape(image)
                    # image size normalization
                    image, scale, pad_left, pad_top = khandy.letterbox_image(
                        image, self.input_width, self.input_height, 0, return_scale=True
                    )
                    # image channel normalization
                    image = khandy.normalize_image_channel(image, swap_rb=True)
                    # image dtype normalization
                    image = khandy.rescale_image(image, "auto", np.float32)
                    # to tensor: move the last axis first, then add a leading axis
                    image = np.transpose(image, (2, 0, 1))
                    image = np.expand_dims(image, axis=0)
                    return image, scale, pad_left, pad_top

                def _post_process(
                    self, outputs_list, scale, pad_left, pad_top, conf_thresh, iou_thresh
                ):
                    """Convert raw network output into filtered detections.

                    Rows of ``outputs_list[0][0]`` are read as: columns 0-3 a box in
                    "cxcywh" format, column 4 a score compared against ``conf_thresh``
                    (presumably objectness -- confirm), columns 5+ per-class scores.

                    Returns:
                        ``(boxes, confs, classes)`` kept by non-maximum suppression,
                        with boxes in "xyxy" format mapped back through the letterbox
                        parameters.
                    """
                    pred = outputs_list[0][0]
                    pred = pred[pred[:, 4] > conf_thresh]
                    boxes = khandy.convert_boxes_format(pred[:, :4], "cxcywh", "xyxy")
                    boxes = khandy.unletterbox_2d_points(boxes, scale, pad_left, pad_top, False)
                    # Combined per-class score, computed once and reused below.
                    scores = pred[:, 5:] * pred[:, 4:5]
                    confs = np.max(scores, axis=-1)
                    classes = np.argmax(scores, axis=-1)
                    keep = khandy.non_max_suppression(boxes, confs, iou_thresh)
                    return boxes[keep], confs[keep], classes[keep]

                def detect(self, image, conf_thresh=0.5, iou_thresh=0.5):
                    """Detect objects in ``image``.

                    Args:
                        image: Image array accepted by ``check_image_dtype_and_shape``.
                        conf_thresh: Minimum value of score column 4 for a candidate.
                        iou_thresh: Threshold passed to non-maximum suppression.

                    Returns:
                        ``(boxes, confs, classes)`` as produced by ``_post_process``.
                    """
                    image, scale, pad_left, pad_top = self._preprocess(image)
                    outputs_list = self.forward(image)
                    boxes, confs, classes = self._post_process(
                        outputs_list,
                        scale=scale,
                        pad_left=pad_left,
                        pad_top=pad_top,
                        conf_thresh=conf_thresh,
                        iou_thresh=iou_thresh,
                    )
                    return boxes, confs, classes
         | 
    	
        insectid/identifier.py
    ADDED
    
    | @@ -0,0 +1,76 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import copy
         | 
| 3 | 
            +
            import khandy
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from .base import OnnxModel
         | 
| 6 | 
            +
            from collections import OrderedDict
         | 
| 7 | 
            +
            from .base import check_image_dtype_and_shape
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class InsectIdentifier(OnnxModel):
                """ONNX-based classifier that names the insect shown in an image.

                Args:
                    model_path: Optional path of the identifier model. Defaults to
                        ``__pycache__/MuGeminorum/insecta/quarrying_insect_identifier.onnx``
                        next to this module.
                    label_map_path: Optional path of the label map text file. Defaults
                        to ``__pycache__/MuGeminorum/insecta/quarrying_insectid_label_map.txt``
                        next to this module.
                """

                def __init__(self, model_path=None, label_map_path=None):
                    current_dir = os.path.dirname(os.path.abspath(__file__))
                    if model_path is None:
                        model_path = os.path.join(
                            current_dir,
                            "__pycache__/MuGeminorum/insecta/quarrying_insect_identifier.onnx",
                        )
                    if label_map_path is None:
                        label_map_path = os.path.join(
                            current_dir,
                            "__pycache__/MuGeminorum/insecta/quarrying_insectid_label_map.txt",
                        )
                    super().__init__(model_path)
                    self.label_name_dict = self._get_label_name_dict(label_map_path)
                    # Looked up by label id 0..N-1, so the ids in the label map are
                    # expected to be contiguous.
                    self.names = [
                        self.label_name_dict[i]["chinese_name"]
                        for i in range(len(self.label_name_dict))
                    ]
                    self.num_classes = len(self.label_name_dict)

                @staticmethod
                def _get_label_name_dict(filename):
                    """Load the label map.

                    Each record is split on "," into exactly three fields: label id,
                    Chinese name and Latin name.

                    Returns:
                        dict mapping ``int`` label id to an ``OrderedDict`` with the
                        keys ``"chinese_name"`` and ``"latin_name"``.
                    """
                    records = khandy.load_list(filename)
                    label_name_dict = {}
                    for record in records:
                        label, chinese_name, latin_name = record.split(",")
                        label_name_dict[int(label)] = OrderedDict(
                            [("chinese_name", chinese_name), ("latin_name", latin_name)]
                        )

                    return label_name_dict

                @staticmethod
                def _preprocess(image):
                    """Turn an image into the network input tensor."""
                    check_image_dtype_and_shape(image)
                    # image size normalization
                    image = khandy.letterbox_image(image, 224, 224)
                    # image channel normalization
                    image = khandy.normalize_image_channel(image, swap_rb=True)
                    # image dtype and value range normalization
                    mean, stddev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    image = khandy.normalize_image_value(image, mean, stddev, "auto")
                    # to tensor: move the last axis first, then add a leading axis
                    image = np.transpose(image, (2, 0, 1))
                    image = np.expand_dims(image, axis=0)
                    return image

                def predict(self, image):
                    """Return the softmax of the model output for ``image``."""
                    inputs = self._preprocess(image)
                    logits = self.forward(inputs)
                    probs = khandy.softmax(logits)
                    return probs

                def identify(self, image, topk=5):
                    """Return the ``topk`` most probable labels for ``image``.

                    Args:
                        image: Image array accepted by ``check_image_dtype_and_shape``.
                        topk: Number of results wanted; values ``<= 0`` or larger than
                            the number of classes select all classes.

                    Returns:
                        A list, in the order given by ``khandy.top_k``, of dicts with
                        the keys ``"chinese_name"``, ``"latin_name"`` and
                        ``"probability"``.
                    """
                    assert isinstance(topk, int)
                    if topk <= 0 or topk > self.num_classes:
                        topk = self.num_classes

                    probs = self.predict(image)
                    topk_probs, topk_indices = khandy.top_k(probs, topk)
                    results = []
                    for ind, prob in zip(topk_indices[0], topk_probs[0]):
                        # Copy so callers cannot modify the shared label table.
                        one_result = copy.deepcopy(self.label_name_dict[ind])
                        one_result["probability"] = prob
                        results.append(one_result)

                    return results
         | 
    	
        khandy/__init__.py
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .dict_utils import *
         | 
| 2 | 
            +
            from .draw_utils import *
         | 
| 3 | 
            +
            from .feature_utils import *
         | 
| 4 | 
            +
            from .file_io_utils import *
         | 
| 5 | 
            +
            from .fs_utils import *
         | 
| 6 | 
            +
            from .hash_utils import *
         | 
| 7 | 
            +
            from .list_utils import *
         | 
| 8 | 
            +
            from .misc import *
         | 
| 9 | 
            +
            from .numpy_utils import *
         | 
| 10 | 
            +
            from .split_utils import *
         | 
| 11 | 
            +
            from .text_utils import *
         | 
| 12 | 
            +
            from .time_utils import *
         | 
| 13 | 
            +
            from .version import *
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from .boxes import *
         | 
| 16 | 
            +
            from .image import *
         | 
| 17 | 
            +
            from .points import *
         | 
| 18 | 
            +
            from . import label
         | 
    	
        khandy/boxes/__init__.py
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .boxes_clip import *
         | 
| 2 | 
            +
            from .boxes_overlap import *
         | 
| 3 | 
            +
            from .boxes_filter import *
         | 
| 4 | 
            +
            from .boxes_convert import *
         | 
| 5 | 
            +
            from .boxes_coder import *
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .boxes_transform_flip import *
         | 
| 8 | 
            +
            from .boxes_transform_rotate import *
         | 
| 9 | 
            +
            from .boxes_transform_scale import *
         | 
| 10 | 
            +
            from .boxes_transform_translate import *
         | 
| 11 | 
            +
            from .boxes_utils import *
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from .boxes_and_indices import *
         | 
    	
        khandy/boxes/boxes_and_indices.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def _concat(arr_list, axis=0):
                """Join a sequence of arrays along ``axis``.

                A sequence holding exactly one array is handed back as that very
                array object, so no copy is made in that case.

                Args:
                    arr_list: sequence of arrays accepted by ``np.concatenate``.
                    axis (int): axis along which the arrays are joined.

                Returns:
                    The single element itself, or the concatenated array.
                """
                if len(arr_list) != 1:
                    return np.concatenate(arr_list, axis)
                # Single element: return it untouched instead of copying it.
                return arr_list[0]
         | 
| 10 | 
            +
                
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
            def convert_boxes_list_to_boxes_and_indices(boxes_list):
                """Flatten per-group boxes into one array plus a group-index column.

                Args:
                    boxes_list (list or tuple): ndarrays with shape (N_i, 4+K),
                        one entry per group.

                Returns:
                    boxes (ndarray): shape (M, 4+K) where M is sum of N_i.
                    indices (ndarray): shape (M, 1) where M is sum of N_i. Each
                        entry is the position in ``boxes_list`` of the array the
                        row came from, stored with the dtype of ``boxes``.

                References:
                    `mmdet.core.bbox.bbox2roi` in mmdetection
                    `convert_boxes_to_roi_format` in TorchVision
                    `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
                """
                assert isinstance(boxes_list, (list, tuple))
                boxes = _concat(boxes_list, axis=0)

                # One (N_i, 1) column per group, filled with that group's position.
                index_dtype = boxes.dtype
                per_group_indices = []
                for group_id, group_boxes in enumerate(boxes_list):
                    per_group_indices.append(
                        np.full((len(group_boxes), 1), group_id, index_dtype))
                indices = _concat(per_group_indices, axis=0)
                return boxes, indices
         | 
| 33 | 
            +
                
         | 
| 34 | 
            +
                
         | 
| 35 | 
            +
            def convert_boxes_and_indices_to_boxes_list(boxes, indices, num_indices):
                """Split boxes into one array per index value.

                Args:
                    boxes (np.ndarray): shape (N, 4+K)
                    indices (np.ndarray): shape (N,) or (N, 1), maybe batch index
                        in mini-batch or class label index.
                    num_indices (int): number of index.

                Returns:
                    list (ndarray): ``num_indices`` arrays; entry ``i`` holds the
                        rows of ``boxes`` whose index equals ``i`` and has shape
                        (N_i, 4+K). Every entry keeps the dtype of ``boxes``,
                        including when N is 0.

                Raises:
                    AssertionError: if ``boxes`` is not 2-D, ``indices`` is neither
                        (N,) nor (N, 1), or their first dimensions differ.

                References:
                    `mmdet.core.bbox2result` in mmdetection
                    `mmdet.core.bbox.roi2bbox` in mmdetection
                    `convert_boxes_to_roi_format` in TorchVision
                    `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
                """
                boxes = np.asarray(boxes)
                indices = np.asarray(indices)
                assert boxes.ndim == 2, "boxes ndim must be 2, got {}".format(boxes.ndim)
                assert (indices.ndim == 1) or (indices.ndim == 2 and indices.shape[-1] == 1), \
                    "indices ndim must be 1 or 2 if last dimension size is 1, got shape {}".format(indices.shape)
                assert boxes.shape[0] == indices.shape[0], "the 1st dimension size of boxes and indices "\
                    "must be the same, got {} != {}".format(boxes.shape[0], indices.shape[0])

                if indices.ndim == 2:
                    indices = np.squeeze(indices, axis=-1)
                # Boolean-mask selection also covers N == 0: it yields (0, 4+K)
                # arrays with the dtype of ``boxes``, so no special case (and no
                # hard-coded float32) is needed for empty input.
                return [boxes[indices == i, :] for i in range(num_indices)]
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
                
         | 
    	
        khandy/boxes/boxes_clip.py
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def clip_boxes(boxes, reference_box, copy=True):
                """Clip boxes to reference box.

                The first four entries along the last axis of ``boxes`` are limited
                to the x range and y range given by ``reference_box``; any further
                entries are left as they are.

                Args:
                    boxes (ndarray): boxes whose last axis starts with
                        [x_min, y_min, x_max, y_max].
                    reference_box: sequence starting with
                        [x_min, y_min, x_max, y_max] of the clipping window.
                    copy (bool): if True a copy of ``boxes`` is clipped and
                        returned, otherwise ``boxes`` is modified in place.

                Returns:
                    ndarray: the clipped boxes.

                References:
                    `clip_to_window` in TensorFlow object detection API.
                """
                clipped = boxes.copy() if copy else boxes
                x_lo, y_lo, x_hi, y_hi = reference_box[:4]
                # View onto the coordinate columns; clipping writes back through it.
                coords = clipped[..., :4]
                np.clip(coords,
                        np.array([x_lo, y_lo, x_lo, y_lo]),
                        np.array([x_hi, y_hi, x_hi, y_hi]),
                        coords)
                return clipped
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                
         | 
| 19 | 
            +
            def clip_boxes_to_image(boxes, image_width, image_height, subpixel=True, copy=True):
                """Clip boxes to image boundaries.

                Args:
                    boxes (ndarray): boxes whose last axis starts with
                        [x_min, y_min, x_max, y_max].
                    image_width: width of the image.
                    image_height: height of the image.
                    subpixel (bool): if True the clipping window is
                        [0, 0, image_width, image_height]; if False it is
                        [0, 0, image_width - 1, image_height - 1].
                    copy (bool): forwarded to ``clip_boxes``.

                Returns:
                    The result of ``clip_boxes`` for the chosen window.

                References:
                    `clip_boxes` in py-faster-rcnn
                    `core.boxes_op_list.clip_to_window` in TensorFlow object detection API.
                    `structures.Boxes.clip` in detectron2

                Notes:
                    With ``subpixel=False`` this is equivalent to
                    `clip_boxes(boxes, [0,0,image_width-1,image_height-1], copy)`
                """
                if subpixel:
                    x_limit, y_limit = image_width, image_height
                else:
                    x_limit, y_limit = image_width - 1, image_height - 1
                return clip_boxes(boxes, [0, 0, x_limit, y_limit], copy)
         | 
    	
        khandy/boxes/boxes_coder.py
    ADDED
    
    | @@ -0,0 +1,69 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class FasterRcnnBoxCoder:
                """Faster RCNN box coder.

                Encodes boxes as offsets relative to reference boxes and decodes
                such offsets back into boxes or points.

                Notes:
                    boxes should be in cxcywh format.
                """

                def __init__(self, stddevs=None):
                    """Constructor for FasterRcnnBoxCoder.

                    Args:
                      stddevs: Sequence of 4 positive scalars to scale tx, ty, tw
                        and th. If set to None (or empty), does not perform
                        scaling. For Faster RCNN, the open-source implementation
                        recommends using [0.1, 0.1, 0.2, 0.2].

                    Raises:
                      AssertionError: if ``stddevs`` is given but does not hold
                        exactly 4 positive values.
                    """
                    if self._is_configured(stddevs):
                        assert len(stddevs) == 4
                        for scalar in stddevs:
                            assert scalar > 0
                    self.stddevs = stddevs

                @staticmethod
                def _is_configured(stddevs):
                    """Return True when ``stddevs`` holds scaling factors."""
                    # A multi-element ndarray has no truth value, so test its size;
                    # everything else keeps plain truthiness (None / empty -> False).
                    if isinstance(stddevs, np.ndarray):
                        return stddevs.size > 0
                    return bool(stddevs)

                def encode(self, boxes, reference_boxes, copy=True):
                    """Encode boxes with respect to reference boxes.

                    Args:
                      boxes (ndarray): boxes in cxcywh format along the last axis.
                      reference_boxes (ndarray): reference boxes in cxcywh format;
                        never modified.
                      copy (bool): if True encode a copy of ``boxes``, otherwise
                        overwrite ``boxes`` in place.

                    Returns:
                      ndarray: codes [tx, ty, tw, th] stored in the first four
                        entries of the last axis.
                    """
                    if copy:
                        boxes = boxes.copy()

                    # The epsilon guards the division and the log against zero
                    # sizes. It is added to a private array so that the caller's
                    # reference boxes are left untouched.
                    reference_sizes = reference_boxes[..., 2:4] + 1e-8
                    boxes[..., 2:4] += 1e-8

                    boxes[..., 0:2] -= reference_boxes[..., 0:2]
                    boxes[..., 0:2] /= reference_sizes
                    boxes[..., 2:4] /= reference_sizes
                    np.log(boxes[..., 2:4], out=boxes[..., 2:4])
                    if self._is_configured(self.stddevs):
                        boxes[..., 0:4] /= self.stddevs
                    return boxes

                def decode(self, rel_boxes, reference_boxes, copy=True):
                    """Decode relative codes to boxes.

                    Args:
                      rel_boxes (ndarray): codes [tx, ty, tw, th] along the last
                        axis, as produced by ``encode``.
                      reference_boxes (ndarray): reference boxes in cxcywh format.
                      copy (bool): if True decode a copy of ``rel_boxes``,
                        otherwise overwrite ``rel_boxes`` in place.

                    Returns:
                      ndarray: boxes in cxcywh format.
                    """
                    if copy:
                        rel_boxes = rel_boxes.copy()

                    if self._is_configured(self.stddevs):
                        rel_boxes[..., 0:4] *= self.stddevs

                    rel_boxes[..., 0:2] *= reference_boxes[..., 2:4]
                    rel_boxes[..., 0:2] += reference_boxes[..., 0:2]
                    np.exp(rel_boxes[..., 2:4], out=rel_boxes[..., 2:4])
                    rel_boxes[..., 2:4] *= reference_boxes[..., 2:4]
                    return rel_boxes

                def decode_points(self, rel_points, reference_boxes, copy=True):
                    """Decode relative codes to points.

                    Args:
                      rel_points (ndarray): codes whose last axis interleaves x
                        and y values: even positions are x, odd positions are y.
                      reference_boxes (ndarray): reference boxes in cxcywh format.
                      copy (bool): if True decode a copy of ``rel_points``,
                        otherwise overwrite ``rel_points`` in place.

                    Returns:
                      ndarray: points with the same layout as ``rel_points``.
                    """
                    if copy:
                        rel_points = rel_points.copy()
                    if self._is_configured(self.stddevs):
                        # Only the two centre factors apply to points.
                        rel_points[..., 0::2] *= self.stddevs[0]
                        rel_points[..., 1::2] *= self.stddevs[1]
                    rel_points[..., 0::2] *= reference_boxes[..., 2:3]
                    rel_points[..., 1::2] *= reference_boxes[..., 3:4]
                    rel_points[..., 0::2] += reference_boxes[..., 0:1]
                    rel_points[..., 1::2] += reference_boxes[..., 1:2]
                    return rel_points
         | 
    	
        khandy/boxes/boxes_convert.py
    ADDED
    
    | @@ -0,0 +1,101 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def convert_xyxy_to_xywh(boxes, copy=True):
                """Convert [x_min, y_min, x_max, y_max] boxes to [x_min, y_min, width, height].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the four
                        corner coordinates.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                top_left = converted[..., 0:2]
                # Size = bottom-right corner minus top-left corner.
                converted[..., 2:4] -= top_left
                return converted
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def convert_xywh_to_xyxy(boxes, copy=True):
                """Convert [x_min, y_min, width, height] boxes to [x_min, y_min, x_max, y_max].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the
                        top-left corner followed by width and height.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                top_left = converted[..., 0:2]
                # Bottom-right corner = top-left corner plus size.
                converted[..., 2:4] += top_left
                return converted
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def convert_xywh_to_cxcywh(boxes, copy=True):
                """Convert [x_min, y_min, width, height] boxes to [cx, cy, width, height].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the
                        top-left corner followed by width and height.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                half_sizes = converted[..., 2:4] * 0.5
                # Centre = top-left corner plus half of the size.
                converted[..., 0:2] += half_sizes
                return converted
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
            def convert_cxcywh_to_xywh(boxes, copy=True):
                """Convert [cx, cy, width, height] boxes to [x_min, y_min, width, height].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the centre
                        followed by width and height.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                half_sizes = converted[..., 2:4] * 0.5
                # Top-left corner = centre minus half of the size.
                converted[..., 0:2] -= half_sizes
                return converted
         | 
| 38 | 
            +
                
         | 
| 39 | 
            +
                
         | 
| 40 | 
            +
            def convert_xyxy_to_cxcywh(boxes, copy=True):
                """Convert [x_min, y_min, x_max, y_max] boxes to [cx, cy, width, height].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the four
                        corner coordinates.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                # Step 1: corners -> size (the top-left corner is still in 0:2).
                converted[..., 2:4] -= converted[..., 0:2]
                # Step 2: top-left corner -> centre, using the size just computed.
                half_sizes = converted[..., 2:4] * 0.5
                converted[..., 0:2] += half_sizes
                return converted
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def convert_cxcywh_to_xyxy(boxes, copy=True):
                """Convert [cx, cy, width, height] boxes to [x_min, y_min, x_max, y_max].

                Args:
                    boxes (ndarray): boxes whose last axis starts with the centre
                        followed by width and height.
                    copy (bool): if True convert a copy, otherwise convert
                        ``boxes`` in place.

                Returns:
                    ndarray: the converted boxes.
                """
                converted = boxes.copy() if copy else boxes
                # Step 1: centre -> top-left corner (the size is still in 2:4).
                half_sizes = converted[..., 2:4] * 0.5
                converted[..., 0:2] -= half_sizes
                # Step 2: size -> bottom-right corner, using the corner just computed.
                converted[..., 2:4] += converted[..., 0:2]
                return converted
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            def convert_boxes_format(boxes, in_fmt, out_fmt, copy=True):
                """Converts boxes from given in_fmt to out_fmt.

                Supported in_fmt and out_fmt are:
                    'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
                    'xywh' : boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height.
                    'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h
                        being width and height.

                Args:
                    boxes: boxes which will be converted.
                    in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
                    out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']
                    copy (bool): if True the conversion works on a copy of
                        ``boxes``, otherwise ``boxes`` is converted in place.

                Returns:
                    boxes: Boxes into converted format.

                Raises:
                    ValueError: if in_fmt or out_fmt is not a supported format.

                References:
                    torchvision.ops.box_convert
                """
                supported = ("xyxy", "xywh", "cxcywh")
                if not (in_fmt in supported and out_fmt in supported):
                    raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

                result = boxes.copy() if copy else boxes
                if in_fmt == out_fmt:
                    return result

                # Every ordered pair of distinct supported formats has a converter.
                converters = {
                    ("xyxy", "xywh"): convert_xyxy_to_xywh,
                    ("xywh", "xyxy"): convert_xywh_to_xyxy,
                    ("xywh", "cxcywh"): convert_xywh_to_cxcywh,
                    ("cxcywh", "xywh"): convert_cxcywh_to_xywh,
                    ("xyxy", "cxcywh"): convert_xyxy_to_cxcywh,
                    ("cxcywh", "xyxy"): convert_cxcywh_to_xyxy,
                }
                # The copy (if requested) was already made above.
                return converters[(in_fmt, out_fmt)](result, copy=False)
         | 
| 101 | 
            +
                
         | 
    	
        khandy/boxes/boxes_filter.py
    ADDED
    
    | @@ -0,0 +1,113 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def filter_small_boxes(boxes, min_width, min_height):
                """Find the boxes that are at least ``min_width`` x ``min_height``.

                Args:
                    boxes: a numpy array with shape [N, 4] holding N boxes in
                        [x_min, y_min, x_max, y_max] order.
                    min_width (float): minimum width
                    min_height (float): minimum height

                Returns:
                    keep: indices of the boxes whose width is at least
                        min_width and whose height is at least min_height.

                References:
                    `_filter_boxes` in py-faster-rcnn
                    `prune_small_boxes` in TensorFlow object detection API.
                    `structures.Boxes.nonempty` in detectron2
                    `ops.boxes.remove_small_boxes` in torchvision
                """
                wide_enough = (boxes[:, 2] - boxes[:, 0]) >= min_width
                tall_enough = (boxes[:, 3] - boxes[:, 1]) >= min_height
                # Turn the boolean mask into the indices of the boxes to keep.
                return np.nonzero(wide_enough & tall_enough)[0]
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            def filter_boxes_outside(boxes, reference_box):
                """Find the boxes that lie entirely within the reference box.

                Args:
                    boxes: numpy array with shape [N, 4+K] holding boxes in
                        [x_min, y_min, x_max, y_max] order.
                    reference_box: sequence starting with
                        [x_min, y_min, x_max, y_max].

                Returns:
                    Indices of the boxes that do not extend beyond the reference
                    box on any side (touching its border is allowed).

                References:
                    `prune_outside_window` in TensorFlow object detection API.
                """
                ref_x_min, ref_y_min, ref_x_max, ref_y_max = reference_box[:4]
                starts_inside = (boxes[:, 0] >= ref_x_min) & (boxes[:, 1] >= ref_y_min)
                ends_inside = (boxes[:, 2] <= ref_x_max) & (boxes[:, 3] <= ref_y_max)
                return np.nonzero(starts_inside & ends_inside)[0]
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def filter_boxes_completely_outside(boxes, reference_box):
                """Find the boxes that are not completely outside the reference box.

                Args:
                    boxes: numpy array with shape [N, 4+K] holding boxes in
                        [x_min, y_min, x_max, y_max] order.
                    reference_box: sequence starting with
                        [x_min, y_min, x_max, y_max].

                Returns:
                    Indices of the boxes that reach strictly into the reference
                    box; boxes that lie fully outside it, or only touch its
                    border, are dropped.

                References:
                    `prune_completely_outside_window` in TensorFlow object detection API.
                """
                ref_x_min, ref_y_min, ref_x_max, ref_y_max = reference_box[:4]
                starts_before_end = (boxes[:, 0] < ref_x_max) & (boxes[:, 1] < ref_y_max)
                ends_after_start = (boxes[:, 2] > ref_x_min) & (boxes[:, 3] > ref_y_min)
                return np.nonzero(starts_before_end & ends_after_start)[0]
         | 
| 53 | 
            +
                
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            def non_max_suppression(boxes, scores, thresh, classes=None, ratio_type="iou"):
                """Greedily select boxes with high confidence.

                Boxes are visited in order of decreasing score. Each visited box is
                kept, and every remaining box whose overlap ratio with it exceeds
                ``thresh`` is discarded.

                Args:
                    boxes: [[x_min, y_min, x_max, y_max], ...]
                    scores: object confidence
                    thresh: retain overlap_ratio <= thresh
                    classes: class labels; when given, suppression is performed
                        independently per class.
                    ratio_type: "iou" or "union" divides the intersection area by
                        the union area, "min" divides it by the smaller of the
                        two box areas.

                Returns:
                    indices to keep

                Raises:
                    ValueError: if ratio_type is not supported.

                References:
                    `py_cpu_nms` in py-faster-rcnn
                    torchvision.ops.nms
                    torchvision.ops.batched_nms
                """
                if boxes.size == 0:
                    return np.empty((0,), dtype=np.int64)

                if classes is not None:
                    # Per-class NMS trick: shift the boxes of every class by an
                    # offset that depends only on the class index and is large
                    # enough that boxes of different classes never overlap.
                    class_shift = classes * (np.max(boxes) + 1)
                    boxes = boxes + class_shift[:, None]

                left, top = boxes[:, 0], boxes[:, 1]
                right, bottom = boxes[:, 2], boxes[:, 3]
                areas = (right - left) * (bottom - top)
                # Candidate indices, highest score first.
                remaining = scores.flatten().argsort()[::-1]

                selected = []
                while remaining.size > 0:
                    best, rest = remaining[0], remaining[1:]
                    selected.append(best)

                    # Intersection of the best box with every other candidate.
                    inter_w = np.maximum(
                        0, np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest]))
                    inter_h = np.maximum(
                        0, np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest]))
                    intersection = inter_w * inter_h

                    if ratio_type in ("union", "iou"):
                        overlap = intersection / (areas[best] + areas[rest] - intersection)
                    elif ratio_type == "min":
                        overlap = intersection / np.minimum(areas[best], areas[rest])
                    else:
                        raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))

                    # Only candidates that overlap little enough stay in the race.
                    remaining = rest[overlap <= thresh]
                return np.asarray(selected)
         | 
| 113 | 
            +
                
         | 
    	
        khandy/boxes/boxes_overlap.py
    ADDED
    
    | @@ -0,0 +1,166 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def paired_intersection(boxes1, boxes2):
                """Compute the intersection area of row-aligned box pairs.

                Row i of ``boxes1`` is intersected with row i of ``boxes2``. Columns
                0..3 of each array are read as (x_min, y_min, x_max, y_max).

                Args:
                    boxes1: a numpy array with shape [N, 4] holding N boxes
                    boxes2: a numpy array with shape [N, 4] holding N boxes

                Returns:
                    a numpy array with shape [N,] representing itemwise intersection area

                References:
                    `core.box_list_ops.matched_intersection` in Tensorflow object detection API

                Notes:
                    Also known as itemwise_intersection, matched_intersection or
                    aligned_intersection.
                """
                # Corner of the overlap region closest to the origin / farthest from it.
                lower_corners = np.maximum(boxes1[:, 0:2], boxes2[:, 0:2])
                upper_corners = np.minimum(boxes1[:, 2:4], boxes2[:, 2:4])
                # Pairs that do not overlap yield a negative extent; clamp it to zero
                # so that their intersection area is 0.
                extents = np.maximum(0, upper_corners - lower_corners)
                return extents[:, 0] * extents[:, 1]
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                
         | 
| 28 | 
            +
            def pairwise_intersection(boxes1, boxes2):
                """Compute pairwise intersection areas between boxes.

                Every box of ``boxes1`` is intersected with every box of ``boxes2``.
                Columns 0..3 of each array are read as (x_min, y_min, x_max, y_max).

                Args:
                    boxes1: a numpy array with shape [N, 4] holding N boxes.
                    boxes2: a numpy array with shape [M, 4] holding M boxes.

                Returns:
                    a numpy array with shape [N, M] representing pairwise intersection
                    area. Its dtype is the dtype NumPy promotes both inputs to, so
                    mixing integer and floating-point boxes no longer truncates the
                    areas to the dtype of one of the inputs.

                References:
                    `core.box_list_ops.intersection` in Tensorflow object detection API
                    `utils.box_list_ops.intersection` in Tensorflow object detection API
                """
                # Broadcast [N, 1, 2] against [1, M, 2] to get the overlap corners of
                # every (boxes1[i], boxes2[j]) pair at once. Empty inputs fall out
                # naturally as a [0, M] or [N, 0] result.
                lower_corners = np.maximum(boxes1[:, None, 0:2], boxes2[None, :, 0:2])
                upper_corners = np.minimum(boxes1[:, None, 2:4], boxes2[None, :, 2:4])
                # Disjoint pairs yield a negative extent; clamp it so their area is 0.
                extents = np.maximum(0, upper_corners - lower_corners)
                return extents[:, :, 0] * extents[:, :, 1]
         | 
| 62 | 
            +
                
         | 
| 63 | 
            +
                
         | 
| 64 | 
            +
            def paired_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
                """Compute paired overlap ratio between boxes.

                Row i of ``boxes1`` is compared with row i of ``boxes2``. Columns 0..3
                of each array are read as (x_min, y_min, x_max, y_max).

                Args:
                    boxes1: a numpy array with shape [N, 4] holding N boxes
                    boxes2: a numpy array with shape [N, 4] holding N boxes
                    ratio_type:
                        iou (alias: union): Intersection-over-union (iou).
                        giou: Generalized iou, i.e. iou minus the fraction of the
                            minimum enclosing box that is not covered by the union.
                        ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
                            their intersection area over box2's area. Note that ioa is not symmetric,
                            that is, IOA(box1, box2) != IOA(box2, box1).
                        min: Compute the ratio as the area of intersection between box1 and box2,
                            divided by the minimum area of the two bounding boxes.

                Returns:
                    a numpy array with shape [N,] representing itemwise overlap ratio.
                    Integer boxes produce a floating-point result.

                Raises:
                    ValueError: if ``ratio_type`` is not one of the values above.

                References:
                    `core.box_list_ops.matched_iou` in Tensorflow object detection API
                    `structures.boxes.matched_boxlist_iou` in detectron2
                    `mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
                """
                intersect_areas = paired_intersection(boxes1, boxes2)
                areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
                areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

                # The divisions below are deliberately out-of-place: an in-place `/=`
                # on an integer array cannot store the floating-point quotient and
                # raises, which made integer boxes unusable.
                if ratio_type in ['union', 'iou', 'giou']:
                    union_areas = areas1 - intersect_areas + areas2
                    ratios = intersect_areas / union_areas
                elif ratio_type == 'min':
                    ratios = intersect_areas / np.minimum(areas1, areas2)
                elif ratio_type == 'ioa':
                    ratios = intersect_areas / areas2
                else:
                    raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))

                if ratio_type == 'giou':
                    # mebb = minimum enclosing bounding boxes
                    mebb_xy_mins = np.minimum(boxes1[:, 0:2], boxes2[:, 0:2])
                    mebb_xy_maxs = np.maximum(boxes1[:, 2:4], boxes2[:, 2:4])
                    mebb_whs = np.maximum(0, mebb_xy_maxs - mebb_xy_mins)
                    mebb_areas = mebb_whs[:, 0] * mebb_whs[:, 1]
                    # giou = iou - (mebb - union) / mebb
                    ratios = ratios + (union_areas - mebb_areas) / mebb_areas
                return ratios
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            def pairwise_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
                """Compute pairwise overlap ratio between boxes.

                Every box of ``boxes1`` is compared with every box of ``boxes2``.
                Columns 0..3 of each array are read as (x_min, y_min, x_max, y_max).

                Args:
                    boxes1: a numpy array with shape [N, 4] holding N boxes
                    boxes2: a numpy array with shape [M, 4] holding M boxes
                    ratio_type:
                        iou (alias: union): Intersection-over-union (iou).
                        giou: Generalized iou, i.e. iou minus the fraction of the
                            minimum enclosing box that is not covered by the union.
                        ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
                            their intersection area over box2's area. Note that ioa is not symmetric,
                            that is, IOA(box1, box2) != IOA(box2, box1).
                        min: Compute the ratio as the area of intersection between box1 and box2,
                            divided by the minimum area of the two bounding boxes.

                Returns:
                    a numpy array with shape [N, M] representing pairwise overlap ratio.
                    Integer boxes produce a floating-point result.

                Raises:
                    ValueError: if ``ratio_type`` is not one of the values above.

                References:
                    `utils.np_box_ops.iou` in Tensorflow object detection API
                    `utils.np_box_ops.ioa` in Tensorflow object detection API
                    `utils.np_box_ops.giou` in Tensorflow object detection API
                    `mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
                    `torchvision.ops.box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.box_iou
                    `torchvision.ops.generalized_box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.generalized_box_iou
                    http://ww2.mathworks.cn/help/vision/ref/bboxoverlapratio.html
                """
                intersect_areas = pairwise_intersection(boxes1, boxes2)
                areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
                areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
                # Shape the per-box areas as a column / a row so they broadcast
                # against the [N, M] intersection matrix.
                areas1_col = np.expand_dims(areas1, axis=1)
                areas2_row = np.expand_dims(areas2, axis=0)

                # The divisions below are deliberately out-of-place: an in-place `/=`
                # on an integer array cannot store the floating-point quotient and
                # raises, which made integer boxes unusable.
                if ratio_type in ['union', 'iou', 'giou']:
                    union_areas = areas1_col - intersect_areas + areas2_row
                    ratios = intersect_areas / union_areas
                elif ratio_type == 'min':
                    ratios = intersect_areas / np.minimum(areas1_col, areas2_row)
                elif ratio_type == 'ioa':
                    ratios = intersect_areas / areas2_row
                else:
                    raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))

                if ratio_type == 'giou':
                    # mebb = minimum enclosing bounding boxes; [N, 1, 2] broadcast
                    # against [M, 2] gives one enclosing box per (i, j) pair.
                    mebb_xy_mins = np.minimum(boxes1[:, None, 0:2], boxes2[:, 0:2])
                    mebb_xy_maxs = np.maximum(boxes1[:, None, 2:4], boxes2[:, 2:4])
                    mebb_whs = np.maximum(0, mebb_xy_maxs - mebb_xy_mins)
                    mebb_areas = mebb_whs[:, :, 0] * mebb_whs[:, :, 1]
                    # giou = iou - (mebb - union) / mebb
                    ratios = ratios + (union_areas - mebb_areas) / mebb_areas
                return ratios
         | 
| 166 | 
            +
                
         | 
    	
        khandy/boxes/boxes_transform_flip.py
    ADDED
    
    | @@ -0,0 +1,135 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from .boxes_utils import assert_and_normalize_shape
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def flip_boxes(boxes, x_center=0, y_center=0, direction='h'):
                """Mirror boxes about a center point along x, y, or both axes.

                Columns 0..3 are read as (x_min, y_min, x_max, y_max); any further
                columns are copied through untouched. The input is converted to
                float32 and a new array is returned.

                Args:
                    boxes: (N, 4+K)
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    direction: str. 'x', 'h' or 'horizontal' mirrors the x coordinates,
                        'y', 'v' or 'vertical' mirrors the y coordinates, and
                        'o', 'b' or 'both' mirrors both.
                """
                x_axis_names = ['x', 'h', 'horizontal']
                y_axis_names = ['y', 'v', 'vertical']
                both_axes_names = ['o', 'b', 'both']
                assert direction in x_axis_names + y_axis_names + both_axes_names

                boxes = np.asarray(boxes, np.float32)
                flipped = boxes.copy()
                num_boxes = boxes.shape[0]

                x_center = np.asarray(x_center, np.float32)
                y_center = np.asarray(y_center, np.float32)
                # Shape checking/normalisation is delegated to the boxes_utils helper.
                x_center = assert_and_normalize_shape(x_center, num_boxes)
                y_center = assert_and_normalize_shape(y_center, num_boxes)

                # Mirroring swaps the roles of min and max: the new minimum is the
                # reflection of the old maximum, and vice versa.
                if direction in both_axes_names + x_axis_names:
                    flipped[:, 0] = 2 * x_center - boxes[:, 2]
                    flipped[:, 2] = 2 * x_center - boxes[:, 0]
                if direction in both_axes_names + y_axis_names:
                    flipped[:, 1] = 2 * y_center - boxes[:, 3]
                    flipped[:, 3] = 2 * y_center - boxes[:, 1]
                return flipped
         | 
| 31 | 
            +
                
         | 
| 32 | 
            +
                
         | 
| 33 | 
            +
            def fliplr_boxes(boxes, x_center=0, y_center=0):
                """Mirror the x coordinates of boxes about ``x_center``.

                Columns 0..3 are read as (x_min, y_min, x_max, y_max); only columns
                0 and 2 change. The input is converted to float32 and a new array is
                returned.

                Args:
                    boxes: (N, 4+K)
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        Converted and shape-checked like ``x_center`` but it does not
                        affect the result.
                """
                boxes = np.asarray(boxes, np.float32)
                num_boxes = boxes.shape[0]
                mirrored = boxes.copy()

                x_center = np.asarray(x_center, np.float32)
                y_center = np.asarray(y_center, np.float32)
                x_center = assert_and_normalize_shape(x_center, num_boxes)
                y_center = assert_and_normalize_shape(y_center, num_boxes)

                # The reflected old x_max becomes the new x_min, and vice versa.
                old_x_min, old_x_max = boxes[:, 0], boxes[:, 2]
                mirrored[:, 0] = 2 * x_center - old_x_max
                mirrored[:, 2] = 2 * x_center - old_x_min
                return mirrored
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
            def flipud_boxes(boxes, x_center=0, y_center=0):
                """Mirror the y coordinates of boxes about ``y_center``.

                Columns 0..3 are read as (x_min, y_min, x_max, y_max); only columns
                1 and 3 change. The input is converted to float32 and a new array is
                returned.

                Args:
                    boxes: (N, 4+K)
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        Converted and shape-checked like ``y_center`` but it does not
                        affect the result.
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                """
                boxes = np.asarray(boxes, np.float32)
                num_boxes = boxes.shape[0]
                mirrored = boxes.copy()

                x_center = np.asarray(x_center, np.float32)
                y_center = np.asarray(y_center, np.float32)
                x_center = assert_and_normalize_shape(x_center, num_boxes)
                y_center = assert_and_normalize_shape(y_center, num_boxes)

                # The reflected old y_max becomes the new y_min, and vice versa.
                old_y_min, old_y_max = boxes[:, 1], boxes[:, 3]
                mirrored[:, 1] = 2 * y_center - old_y_max
                mirrored[:, 3] = 2 * y_center - old_y_min
                return mirrored
         | 
| 71 | 
            +
                
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
            def transpose_boxes(boxes, x_center=0, y_center=0):
                """Swap the x and y coordinates of boxes relative to a center point.

                Each corner (x, y) is mapped to
                (y + (x_center - y_center), x - (x_center - y_center)); with both
                centers at 0 this is a plain x/y swap. Columns 0..3 are read as
                (x_min, y_min, x_max, y_max), further columns are copied through. The
                input is converted to float32 and a new array is returned.

                Args:
                    boxes: (N, 4+K)
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                """
                boxes = np.asarray(boxes, np.float32)
                num_boxes = boxes.shape[0]
                transposed = boxes.copy()

                x_center = np.asarray(x_center, np.float32)
                y_center = np.asarray(y_center, np.float32)
                x_center = assert_and_normalize_shape(x_center, num_boxes)
                y_center = assert_and_normalize_shape(y_center, num_boxes)

                center_offset = x_center - y_center
                # New x values come from the old y columns, new y values from the
                # old x columns.
                transposed[:, 0] = boxes[:, 1] + center_offset
                transposed[:, 2] = boxes[:, 3] + center_offset
                transposed[:, 1] = boxes[:, 0] - center_offset
                transposed[:, 3] = boxes[:, 2] - center_offset
                return transposed
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            def flip_boxes_in_image(boxes, image_width, image_height, direction='h'):
                """Flip boxes about the center of an image.

                The flip center is ((image_width - 1) / 2, (image_height - 1) / 2) and
                the actual mirroring is done by ``flip_boxes``.

                Args:
                    boxes: (N, 4+K)
                    image_width: int
                    image_height: int
                    direction: str, forwarded unchanged to ``flip_boxes``

                References:
                    `core.bbox.bbox_flip` in mmdetection
                    `datasets.pipelines.RandomFlip.bbox_flip` in mmdetection
                """
                center_x = 0.5 * (image_width - 1)
                center_y = 0.5 * (image_height - 1)
                return flip_boxes(boxes, center_x, center_y, direction=direction)
         | 
| 112 | 
            +
                
         | 
| 113 | 
            +
                
         | 
| 114 | 
            +
            def rot90_boxes_in_image(boxes, image_width, image_height, n=1):
                """Rotate boxes counter-clockwise by 90 degrees, ``n`` times.

                ``n`` is reduced modulo 4. A remainder of 0 returns ``boxes.copy()``,
                2 flips the boxes in both directions, and 1 / 3 transpose the boxes
                and then flip them vertically / horizontally.

                NOTE(review): for odd ``n`` the flip after the transpose uses
                ``image_width`` / ``image_height`` exactly as passed in. That appears
                to match ``np.rot90`` only if they describe the rotated image (or the
                image is square) -- confirm against callers.

                Args:
                    boxes: (N, 4+K)
                    image_width: int
                    image_height: int
                    n: int, number of 90 degree rotations

                References:
                    np.rot90
                    cv2.rotate
                    tf.image.rot90
                """
                quarter_turns = n % 4
                if quarter_turns == 0:
                    # No rotation: hand back a copy so the caller's input is not aliased.
                    return boxes.copy()
                if quarter_turns == 2:
                    return flip_boxes_in_image(boxes, image_width, image_height, 'o')
                # A quarter turn is a transpose followed by a single-axis flip.
                flip_direction = 'v' if quarter_turns == 1 else 'h'
                transposed = transpose_boxes(boxes)
                return flip_boxes_in_image(transposed, image_width, image_height, flip_direction)
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
                
         | 
    	
        khandy/boxes/boxes_transform_rotate.py
    ADDED
    
    | @@ -0,0 +1,140 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from .boxes_utils import assert_and_normalize_shape
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def rotate_boxes(boxes, angle, x_center=0, y_center=0, scale=1,
                             degrees=True, return_rotated_boxes=False):
                """Rotate (and uniformly scale) boxes about a given point.

                Every corner (x, y) of every box is mapped to
                    x' = s*cos(a) * (x - x_center) - s*sin(a) * (y - y_center) + x_center
                    y' = s*sin(a) * (x - x_center) + s*cos(a) * (y - y_center) + y_center
                where ``a`` is ``angle`` and ``s`` is ``scale``.

                Args:
                    boxes: (N, 4+K), columns 0..3 are x_min, y_min, x_max, y_max
                    angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in x and y dimension
                    degrees: bool
                        if True ``angle`` is given in degrees, otherwise in radians
                    return_rotated_boxes: bool
                        if True the rotated corner coordinates are returned as well

                Returns:
                    ret_boxes: (N, 4+K) float32 array. The first four columns hold
                        the axis-aligned box enclosing each rotated box; any extra
                        columns are copied from ``boxes`` unchanged.
                    rotated_boxes: (N, 8) array, only when ``return_rotated_boxes``
                        is True. The columns are x, y of the transformed corners
                        (x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max).
                """
                boxes = np.asarray(boxes, np.float32)
                num_boxes = boxes.shape[0]

                angle = assert_and_normalize_shape(
                    np.asarray(angle, np.float32), num_boxes)
                x_center = assert_and_normalize_shape(
                    np.asarray(x_center, np.float32), num_boxes)
                y_center = assert_and_normalize_shape(
                    np.asarray(y_center, np.float32), num_boxes)
                scale = assert_and_normalize_shape(
                    np.asarray(scale, np.float32), num_boxes)

                if degrees:
                    angle = np.deg2rad(angle)
                cos_val = scale * np.cos(angle)
                sin_val = scale * np.sin(angle)
                # Constant part of the affine map (rotation about the center
                # expressed as rotation about the origin plus a shift).
                x_shift = x_center - x_center * cos_val + y_center * sin_val
                y_shift = y_center - x_center * sin_val - y_center * cos_val

                x_mins, y_mins = boxes[:, 0], boxes[:, 1]
                x_maxs, y_maxs = boxes[:, 2], boxes[:, 3]
                x00 = x_mins * cos_val - y_mins * sin_val + x_shift
                x10 = x_maxs * cos_val - y_mins * sin_val + x_shift
                x11 = x_maxs * cos_val - y_maxs * sin_val + x_shift
                x01 = x_mins * cos_val - y_maxs * sin_val + x_shift

                y00 = x_mins * sin_val + y_mins * cos_val + y_shift
                y10 = x_maxs * sin_val + y_mins * cos_val + y_shift
                y11 = x_maxs * sin_val + y_maxs * cos_val + y_shift
                y01 = x_mins * sin_val + y_maxs * cos_val + y_shift

                rotated_boxes = np.stack(
                    [x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
                x_coords = rotated_boxes[:, 0::2]
                y_coords = rotated_boxes[:, 1::2]

                # The previous `boxes.ndim == 4` special case could never be taken
                # for an (N, 4+K) array, so the copy-and-overwrite path is the only
                # behaviour callers have ever observed; keep exactly that.
                ret_boxes = boxes.copy()
                ret_boxes[:, :4] = np.stack([
                    np.min(x_coords, axis=1),
                    np.min(y_coords, axis=1),
                    np.max(x_coords, axis=1),
                    np.max(y_coords, axis=1),
                ], axis=-1)

                if return_rotated_boxes:
                    return ret_boxes, rotated_boxes
                return ret_boxes
         | 
| 65 | 
            +
                
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
            def rotate_boxes_wrt_centers(boxes, angle, scale=1, degrees=True,
                                         return_rotated_boxes=False):
                """Rotate (and uniformly scale) each box about its own center.

                Args:
                    boxes: (N, 4+K), columns 0..3 are x_min, y_min, x_max, y_max
                    angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in x and y dimension
                    degrees: bool
                        if True ``angle`` is given in degrees, otherwise in radians
                    return_rotated_boxes: bool
                        if True the rotated corner coordinates are returned as well

                Returns:
                    ret_boxes: (N, 4+K) float32 array. The first four columns hold
                        the axis-aligned box enclosing each rotated box; any extra
                        columns are copied from ``boxes`` unchanged.
                    rotated_boxes: (N, 8) array, only when ``return_rotated_boxes``
                        is True. The columns are x, y of the transformed corners
                        (x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max).
                """
                boxes = np.asarray(boxes, np.float32)
                num_boxes = boxes.shape[0]

                angle = assert_and_normalize_shape(
                    np.asarray(angle, np.float32), num_boxes)
                scale = assert_and_normalize_shape(
                    np.asarray(scale, np.float32), num_boxes)

                if degrees:
                    angle = np.deg2rad(angle)
                cos_val = scale * np.cos(angle)
                sin_val = scale * np.sin(angle)

                x_centers = (boxes[:, 2] + boxes[:, 0]) * 0.5
                y_centers = (boxes[:, 3] + boxes[:, 1]) * 0.5
                half_widths = (boxes[:, 2] - boxes[:, 0]) * 0.5
                half_heights = (boxes[:, 3] - boxes[:, 1]) * 0.5

                half_widths_cos = half_widths * cos_val
                half_widths_sin = half_widths * sin_val
                half_heights_cos = half_heights * cos_val
                half_heights_sin = half_heights * sin_val

                # Corner offsets (+-half_width, +-half_height) rotated about the
                # origin, then moved back to the box center.
                x00 = x_centers - half_widths_cos + half_heights_sin
                x10 = x_centers + half_widths_cos + half_heights_sin
                x11 = x_centers + half_widths_cos - half_heights_sin
                x01 = x_centers - half_widths_cos - half_heights_sin

                y00 = y_centers - half_widths_sin - half_heights_cos
                y10 = y_centers + half_widths_sin - half_heights_cos
                y11 = y_centers + half_widths_sin + half_heights_cos
                y01 = y_centers - half_widths_sin + half_heights_cos

                rotated_boxes = np.stack(
                    [x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
                x_coords = rotated_boxes[:, 0::2]
                y_coords = rotated_boxes[:, 1::2]

                # The previous `boxes.ndim == 4` special case could never be taken
                # for an (N, 4+K) array, so the copy-and-overwrite path is the only
                # behaviour callers have ever observed; keep exactly that.
                ret_boxes = boxes.copy()
                ret_boxes[:, :4] = np.stack([
                    np.min(x_coords, axis=1),
                    np.min(y_coords, axis=1),
                    np.max(x_coords, axis=1),
                    np.max(y_coords, axis=1),
                ], axis=-1)

                if return_rotated_boxes:
                    return ret_boxes, rotated_boxes
                return ret_boxes
         | 
| 139 | 
            +
                
         | 
| 140 | 
            +
                
         | 
    	
        khandy/boxes/boxes_transform_scale.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from .boxes_utils import assert_and_normalize_shape
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def scale_boxes(boxes, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
                """Scale boxes coordinates in x and y dimensions.

                Every x coordinate is mapped to ``x_scale * (x - x_center) + x_center``
                and every y coordinate to ``y_scale * (y - y_center) + y_center``.

                Args:
                    boxes: (N, 4+K)
                    x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in x dimension
                    y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in y dimension
                    x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of scaled boxes; extra columns are kept.

                References:
                    `core.box_list_ops.scale` in TensorFlow object detection API
                    `utils.box_list_ops.scale` in TensorFlow object detection API
                    `datasets.pipelines.Resize._resize_bboxes` in mmdetection
                    `core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
                    `layers.mask_ops.scale_boxes` in detectron2
                    `mmcv.bbox_scaling`
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)
                num_boxes = boxes.shape[0]

                x_scale = assert_and_normalize_shape(
                    np.asarray(x_scale, np.float32), num_boxes)
                y_scale = assert_and_normalize_shape(
                    np.asarray(y_scale, np.float32), num_boxes)
                x_center = assert_and_normalize_shape(
                    np.asarray(x_center, np.float32), num_boxes)
                y_center = assert_and_normalize_shape(
                    np.asarray(y_center, np.float32), num_boxes)

                # Out-of-place multiply: an in-place `*=` fails to broadcast when
                # the scale has shape (1,) while the center has shape (N,).
                x_shift = (1 - x_scale) * x_center
                y_shift = (1 - y_scale) * y_center

                boxes[:, 0] *= x_scale
                boxes[:, 1] *= y_scale
                boxes[:, 2] *= x_scale
                boxes[:, 3] *= y_scale
                boxes[:, 0] += x_shift
                boxes[:, 1] += y_shift
                boxes[:, 2] += x_shift
                boxes[:, 3] += y_shift
                return boxes
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
            def scale_boxes_wrt_centers(boxes, x_scale=1, y_scale=1, copy=True):
                """Scale each box about its own center.

                The width of every box is multiplied by ``x_scale`` and the height
                by ``y_scale`` while the box center stays where it is.

                Args:
                    boxes: (N, 4+K)
                    x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in x dimension
                    y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        scale factor in y dimension
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of scaled boxes; extra columns are kept.

                References:
                    `core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
                    `layers.mask_ops.scale_boxes` in detectron2
                    `mmcv.bbox_scaling`
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)
                num_boxes = boxes.shape[0]

                x_scale = assert_and_normalize_shape(
                    np.asarray(x_scale, np.float32), num_boxes)
                y_scale = assert_and_normalize_shape(
                    np.asarray(y_scale, np.float32), num_boxes)

                # Each side moves outwards by half of the total size change.
                x_factor = (x_scale - 1) * 0.5
                y_factor = (y_scale - 1) * 0.5
                x_deltas = boxes[:, 2] - boxes[:, 0]
                y_deltas = boxes[:, 3] - boxes[:, 1]
                x_deltas *= x_factor
                y_deltas *= y_factor

                boxes[:, 0] -= x_deltas
                boxes[:, 1] -= y_deltas
                boxes[:, 2] += x_deltas
                boxes[:, 3] += y_deltas
                return boxes
         | 
| 86 | 
            +
             | 
    	
        khandy/boxes/boxes_transform_translate.py
    ADDED
    
    | @@ -0,0 +1,136 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from .boxes_utils import assert_and_normalize_shape
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def translate_boxes(boxes, x_shift=0, y_shift=0, copy=True):
                """Translate boxes coordinates in x and y dimensions.

                Args:
                    boxes: (N, 4+K)
                    x_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift in x dimension
                    y_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift in y dimension
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of shifted boxes; extra columns are kept.

                References:
                    `datasets.pipelines.RandomCrop` in mmdetection
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)
                num_boxes = boxes.shape[0]

                x_shift = assert_and_normalize_shape(
                    np.asarray(x_shift, np.float32), num_boxes)
                y_shift = assert_and_normalize_shape(
                    np.asarray(y_shift, np.float32), num_boxes)

                boxes[:, 0] += x_shift
                boxes[:, 1] += y_shift
                boxes[:, 2] += x_shift
                boxes[:, 3] += y_shift
                return boxes
         | 
| 32 | 
            +
                
         | 
| 33 | 
            +
                
         | 
| 34 | 
            +
            def adjust_boxes(boxes, x_min_shift, y_min_shift, x_max_shift, y_max_shift, copy=True):
                """Shift the (x_min, y_min) and (x_max, y_max) corners independently.

                Args:
                    boxes: (N, 4+K)
                    x_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift (x_min, y_min) in x dimension
                    y_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift (x_min, y_min) in y dimension
                    x_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift (x_max, y_max) in x dimension
                    y_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        shift (x_max, y_max) in y dimension
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of adjusted boxes; extra columns are kept.
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)
                num_boxes = boxes.shape[0]

                x_min_shift = assert_and_normalize_shape(
                    np.asarray(x_min_shift, np.float32), num_boxes)
                y_min_shift = assert_and_normalize_shape(
                    np.asarray(y_min_shift, np.float32), num_boxes)
                x_max_shift = assert_and_normalize_shape(
                    np.asarray(x_max_shift, np.float32), num_boxes)
                y_max_shift = assert_and_normalize_shape(
                    np.asarray(y_max_shift, np.float32), num_boxes)

                boxes[:, 0] += x_min_shift
                boxes[:, 1] += y_min_shift
                boxes[:, 2] += x_max_shift
                boxes[:, 3] += y_max_shift
                return boxes
         | 
| 65 | 
            +
                
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
            def inflate_or_deflate_boxes(boxes, width_delta=0, height_delta=0, copy=True):
                """Grow (positive delta) or shrink (negative delta) boxes about their centers.

                Half of ``width_delta`` is subtracted from x_min and added to x_max,
                and likewise for ``height_delta`` and y_min / y_max.

                Args:
                    boxes: (N, 4+K)
                    width_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        total change of the box width
                    height_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
                        total change of the box height
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of resized boxes; extra columns are kept.
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)
                num_boxes = boxes.shape[0]

                width_delta = assert_and_normalize_shape(
                    np.asarray(width_delta, np.float32), num_boxes)
                height_delta = assert_and_normalize_shape(
                    np.asarray(height_delta, np.float32), num_boxes)

                half_width_delta = width_delta * 0.5
                half_height_delta = height_delta * 0.5
                boxes[:, 0] -= half_width_delta
                boxes[:, 1] -= half_height_delta
                boxes[:, 2] += half_width_delta
                boxes[:, 3] += half_height_delta
                return boxes
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            def inflate_boxes_to_square(boxes, copy=True):
                """Inflate boxes to square.

                The shorter side of every box is enlarged symmetrically about the
                box center until it equals the longer side.

                Args:
                    boxes: (N, 4+K)
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of square boxes; extra columns are kept.
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)

                widths = boxes[:, 2] - boxes[:, 0]
                heights = boxes[:, 3] - boxes[:, 1]
                max_side_lengths = np.maximum(widths, heights)

                # Each side moves outwards by half of the missing length.
                width_deltas = (max_side_lengths - widths) * 0.5
                height_deltas = (max_side_lengths - heights) * 0.5
                boxes[:, 0] -= width_deltas
                boxes[:, 1] -= height_deltas
                boxes[:, 2] += width_deltas
                boxes[:, 3] += height_deltas
                return boxes
         | 
| 113 | 
            +
                
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            def deflate_boxes_to_square(boxes, copy=True):
                """Deflate boxes to square.

                The longer side of every box is shrunk symmetrically about the
                box center until it equals the shorter side.

                Args:
                    boxes: (N, 4+K)
                    copy: bool
                        if True (default) the input is left untouched. If False the
                        input is modified in place when it already is a float32
                        ndarray; otherwise a converted copy is still made.

                Returns:
                    (N, 4+K) float32 array of square boxes; extra columns are kept.
                """
                # np.array(..., copy=False) raises on NumPy >= 2.0 whenever a copy
                # is unavoidable (list input, dtype change); np.asarray copies
                # only if needed on every NumPy version.
                if copy:
                    boxes = np.array(boxes, dtype=np.float32)
                else:
                    boxes = np.asarray(boxes, dtype=np.float32)

                widths = boxes[:, 2] - boxes[:, 0]
                heights = boxes[:, 3] - boxes[:, 1]
                min_side_lengths = np.minimum(widths, heights)

                # Deltas are <= 0 here, so the sides move inwards.
                width_deltas = (min_side_lengths - widths) * 0.5
                height_deltas = (min_side_lengths - heights) * 0.5
                boxes[:, 0] -= width_deltas
                boxes[:, 1] -= height_deltas
                boxes[:, 2] += width_deltas
                boxes[:, 3] += height_deltas
                return boxes
         | 
| 136 | 
            +
             | 
    	
        khandy/boxes/boxes_utils.py
    ADDED
    
    | @@ -0,0 +1,28 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def assert_and_normalize_shape(x, length):
                """Validate a per-box parameter and flatten it for broadcasting.

                Accepted shapes are (), (1,), (length,), (1, 1) and (length, 1).
                Two-dimensional inputs have their trailing axis squeezed away so
                the result broadcasts against a 1-D column of ``length`` values.

                Args:
                    x: ndarray (any array-like is converted with ``np.asarray``)
                    length: int
                        number of boxes the parameter has to be compatible with

                Returns:
                    ndarray of shape (), (1,) or (length,). 0-d and 1-d ndarray
                    inputs are returned as they are.

                Raises:
                    ValueError: if the shape or the number of dimensions of ``x``
                        is not one of the accepted ones.
                """
                # No-op for ndarrays; lets callers pass scalars or lists directly.
                x = np.asarray(x)

                if x.ndim == 0:
                    return x
                if x.ndim == 1:
                    if x.shape[0] in (1, length):
                        return x
                    raise ValueError('Incompatible shape!')
                if x.ndim == 2:
                    if x.shape in ((1, 1), (length, 1)):
                        return np.squeeze(x, axis=-1)
                    raise ValueError('Incompatible shape!')
                raise ValueError('Incompatible ndim!')
         | 
| 28 | 
            +
                    
         | 
    	
        khandy/dict_utils.py
    ADDED
    
    | @@ -0,0 +1,168 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            from collections import OrderedDict
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def get_dict_first_item(dict_obj):
                """Return the first ``(key, value)`` pair in iteration order.

                Args:
                    dict_obj: mapping to inspect.

                Returns:
                    A ``(key, value)`` tuple, or None when ``dict_obj`` is empty.
                """
                missing = object()
                first_key = next(iter(dict_obj), missing)
                if first_key is missing:
                    return None
                return first_key, dict_obj[first_key]
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def sort_dict(dict_obj, key=None, reverse=False):
                """Return an OrderedDict holding the items of ``dict_obj`` in sorted order.

                Args:
                    dict_obj: mapping to sort.
                    key: optional callable applied to each ``(key, value)`` pair to
                        produce its sort key; pairs are compared directly when None.
                    reverse: passed through to ``sorted``.

                Returns:
                    OrderedDict with the sorted items.
                """
                ordered_pairs = sorted(dict_obj.items(), key=key, reverse=reverse)
                return OrderedDict(ordered_pairs)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def create_multidict(key_list, value_list):
                """Group values by key into a dict of lists.

                Args:
                    key_list: sequence of keys, one per value.
                    value_list: sequence of values, same length as ``key_list``.

                Returns:
                    dict mapping each distinct key to the list of values paired
                    with it, in input order.
                """
                assert len(key_list) == len(value_list)
                grouped = {}
                for k, v in zip(key_list, value_list):
                    if k in grouped:
                        grouped[k].append(v)
                    else:
                        grouped[k] = [v]
                return grouped
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def convert_multidict_to_list(multidict_obj):
                """Flatten a multidict into parallel key and value lists.

                Args:
                    multidict_obj: mapping from key to a sized iterable of values.

                Returns:
                    ``(key_list, value_list)`` where each key is repeated once per
                    value it owns, aligned index-by-index with ``value_list``.
                """
                flat_keys = []
                flat_values = []
                for key, values in multidict_obj.items():
                    flat_keys.extend([key] * len(values))
                    flat_values.extend(values)
                return flat_keys, flat_values
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def convert_multidict_to_records(multidict_obj, key_map=None, raise_if_key_error=True):
                """Render every (value, key) pair of a multidict as a 'value,key' string.

                Args:
                    multidict_obj: mapping from key to an iterable of values.
                    key_map: optional mapping used to translate each key before it is
                        written into the record.
                    raise_if_key_error: when ``key_map`` is given, look keys up with
                        ``key_map[key]`` (so a missing key raises KeyError) if True;
                        otherwise fall back to the untranslated key.

                Returns:
                    list of strings formatted as ``'{value},{key}'``.
                """
                records = []
                for key in multidict_obj:
                    if key_map is None:
                        label = key
                    elif raise_if_key_error:
                        label = key_map[key]
                    else:
                        label = key_map.get(key, key)
                    records.extend('{},{}'.format(value, label)
                                   for value in multidict_obj[key])
                return records
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
                
         | 
| 47 | 
            +
            def sample_multidict(multidict_obj, num_keys, num_per_key=None):
                """Randomly pick keys, and optionally values per key, from a multidict.

                Args:
                    multidict_obj: mapping from key to a sequence of values.
                    num_keys: number of keys to draw; capped at the number of keys.
                    num_per_key: number of values to draw for each chosen key, capped
                        at what that key holds. When None the chosen keys map to the
                        very same value objects as in ``multidict_obj``.

                Returns:
                    dict containing the sampled keys and their (sampled) values.
                """
                key_count = min(num_keys, len(multidict_obj))
                chosen_keys = random.sample(list(multidict_obj), key_count)
                if num_per_key is None:
                    return {key: multidict_obj[key] for key in chosen_keys}

                sampled = {}
                for key in chosen_keys:
                    candidates = multidict_obj[key]
                    take = min(num_per_key, len(candidates))
                    sampled[key] = random.sample(candidates, take)
                return sampled
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
            def split_multidict_on_key(multidict_obj, split_ratio, use_shuffle=False):
                """Partition the keys of a multidict into consecutive groups.

                Args:
                    multidict_obj: dict to split.
                    split_ratio: list or tuple of relative weights, one per output
                        part; they are normalised by their sum.
                    use_shuffle: shuffle the key order before cutting when True.

                Returns:
                    list of dicts, one per entry of ``split_ratio``; each key of
                    ``multidict_obj`` lands in at most one of them together with its
                    original value object.
                """
                assert isinstance(multidict_obj, dict)
                assert isinstance(split_ratio, (list, tuple))

                total = float(sum(split_ratio))
                pdf = [weight / total for weight in split_ratio]
                # Cumulative fractions 0, p0, p0+p1, ... mark the cut points.
                cdf = [sum(pdf[:end]) for end in range(len(pdf) + 1)]
                num_keys = len(multidict_obj)
                bounds = [int(round(num_keys * frac)) for frac in cdf]

                ordered_keys = list(multidict_obj)
                if use_shuffle:
                    random.shuffle(ordered_keys)

                return [
                    {key: multidict_obj[key] for key in ordered_keys[start:stop]}
                    for start, stop in zip(bounds[:-1], bounds[1:])
                ]
         | 
| 79 | 
            +
                
         | 
| 80 | 
            +
                
         | 
| 81 | 
            +
            def split_multidict_on_value(multidict_obj, split_ratio, use_shuffle=False):
                """Partition the values under every key into consecutive groups.

                Args:
                    multidict_obj: dict mapping each key to a sliceable sequence.
                    split_ratio: list or tuple of relative weights, one per output
                        part; they are normalised by their sum.
                    use_shuffle: shuffle a copy of each value sequence before
                        cutting when True.

                Returns:
                    list of dicts, one per entry of ``split_ratio``. Every output
                    dict contains every key, mapped to its share of that key's
                    values (possibly empty).
                """
                assert isinstance(multidict_obj, dict)
                assert isinstance(split_ratio, (list, tuple))

                total = float(sum(split_ratio))
                pdf = [weight / total for weight in split_ratio]
                # Cumulative fractions 0, p0, p0+p1, ... mark the cut points.
                cdf = [sum(pdf[:end]) for end in range(len(pdf) + 1)]
                parts = [dict() for _ in split_ratio]
                for key, values in multidict_obj.items():
                    bounds = [int(round(len(values) * frac)) for frac in cdf]
                    # Work on a slice copy so shuffling never reorders the input.
                    pool = values[:]
                    if use_shuffle:
                        random.shuffle(pool)
                    for part, start, stop in zip(parts, bounds[:-1], bounds[1:]):
                        part[key] = pool[start:stop]
                return parts
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
            def get_multidict_info(multidict_obj, with_print=False, desc=None):
                """Summarise how many keys and values a multidict holds.

                Args:
                    multidict_obj: mapping from key to a sized collection of values.
                    with_print: also print the summary when True.
                    desc: label prefixed to every printed line; '<unknown>' is used
                        when it is falsy.

                Returns:
                    dict with the entries 'num_keys', 'num_values',
                    'max_values_per_key', 'min_values_per_key' and
                    'avg_values_per_key'. All of them are 0 for an empty multidict.
                """
                num_list = [len(val) for val in multidict_obj.values()]
                num_keys = len(num_list)
                num_values = sum(num_list)
                if num_keys == 0:
                    # max()/min() raise ValueError on an empty sequence, so the
                    # empty case must be handled before they are evaluated.
                    max_values_per_key = 0
                    min_values_per_key = 0
                    avg_values_per_key = 0
                else:
                    max_values_per_key = max(num_list)
                    min_values_per_key = min(num_list)
                    avg_values_per_key = num_values / num_keys
                info = {
                    'num_keys': num_keys,
                    'num_values': num_values,
                    'max_values_per_key': max_values_per_key,
                    'min_values_per_key': min_values_per_key,
                    'avg_values_per_key': avg_values_per_key,
                }
                if with_print:
                    desc = desc or '<unknown>'
                    print('{} key number:    {}'.format(desc, info['num_keys']))
                    print('{} value number:    {}'.format(desc, info['num_values']))
                    print('{} max number per-key: {}'.format(desc, info['max_values_per_key']))
                    print('{} min number per-key: {}'.format(desc, info['min_values_per_key']))
                    print('{} avg number per-key: {:.2f}'.format(desc, info['avg_values_per_key']))
                return info
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            def filter_multidict_by_number(multidict_obj, lower, upper=None):
                """Keep only the keys whose value count lies within given bounds.

                Args:
                    multidict_obj: mapping from key to a sized collection of values.
                    lower: minimum number of values (inclusive).
                    upper: maximum number of values (inclusive); no upper bound
                        when None. Must not be smaller than ``lower``.

                Returns:
                    dict with the surviving keys mapped to their original values.
                """
                bounded = upper is not None
                if bounded:
                    assert lower <= upper, 'lower must not be greater than upper'

                kept = {}
                for key, value in multidict_obj.items():
                    count = len(value)
                    if lower <= count and (not bounded or count <= upper):
                        kept[key] = value
                return kept
         | 
| 135 | 
            +
                    
         | 
| 136 | 
            +
                    
         | 
| 137 | 
            +
            def sort_multidict_by_number(multidict_obj, num_keys_to_keep=None, reverse=True):
                """Order keys by how many values they hold and keep the leading ones.

                Args:
                    multidict_obj: mapping from key to a sized collection of values.
                    num_keys_to_keep: how many keys to retain from the front of the
                        sorted order; all keys when None, capped at the key count.
                    reverse: passed to ``sorted``; True (the default) puts keys with
                        the most values first, False puts keys with the fewest first.

                Returns:
                    OrderedDict with the retained keys in sorted order.
                """
                limit = len(multidict_obj)
                if num_keys_to_keep is not None:
                    limit = min(num_keys_to_keep, limit)
                ranked = sorted(multidict_obj.items(),
                                key=lambda pair: len(pair[1]),
                                reverse=reverse)
                kept = OrderedDict()
                for position in range(limit):
                    key, values = ranked[position]
                    kept[key] = values
                return kept
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                
         | 
| 153 | 
            +
            def merge_multidict(*mdicts):
                """Combine several multidicts, concatenating values of shared keys.

                Args:
                    *mdicts: mappings from key to an iterable of values.

                Returns:
                    dict mapping every key seen to a new list holding its values
                    from all inputs, in argument order.
                """
                merged = {}
                for mdict in mdicts:
                    for key, values in mdict.items():
                        if key not in merged:
                            merged[key] = []
                        merged[key].extend(values)
                return merged
         | 
| 159 | 
            +
                
         | 
| 160 | 
            +
                
         | 
| 161 | 
            +
            def invert_multidict(multidict_obj):
                """Swap the roles of keys and values in a multidict.

                Args:
                    multidict_obj: mapping from key to an iterable of hashable values.

                Returns:
                    dict mapping each value to the list of keys it appeared under
                    (a key is listed once per occurrence).
                """
                inverted = {}
                for key, values in multidict_obj.items():
                    for member in values:
                        if member in inverted:
                            inverted[member].append(key)
                        else:
                            inverted[member] = [key]
                return inverted
         | 
| 167 | 
            +
                
         | 
| 168 | 
            +
                
         | 
    	
        khandy/draw_utils.py
    ADDED
    
    | @@ -0,0 +1,148 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import PIL
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            from PIL import ImageDraw
         | 
| 5 | 
            +
            from PIL import ImageFont
         | 
| 6 | 
            +
            from PIL import ImageColor
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def _is_legal_color(color):
         | 
| 10 | 
            +
                if color is None:
         | 
| 11 | 
            +
                    return True
         | 
| 12 | 
            +
                if isinstance(color, str):
         | 
| 13 | 
            +
                    return True
         | 
| 14 | 
            +
                return isinstance(color, (tuple, list)) and len(color) == 3
         | 
| 15 | 
            +
                    
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            def _normalize_color(color, pil_mode, swap_rgb=False):
         | 
| 18 | 
            +
                if color is None:
         | 
| 19 | 
            +
                    return color
         | 
| 20 | 
            +
                if isinstance(color, str):
         | 
| 21 | 
            +
                    color = ImageColor.getrgb(color)
         | 
| 22 | 
            +
                gray = color[0]
         | 
| 23 | 
            +
                if swap_rgb:
         | 
| 24 | 
            +
                    color = (color[2], color[1], color[0])
         | 
| 25 | 
            +
                if pil_mode == 'L':
         | 
| 26 | 
            +
                    color = gray
         | 
| 27 | 
            +
                return color
         | 
| 28 | 
            +
                
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
            def draw_text(image, text, position, color=(255,0,0), font=None, font_size=15):
                """Draws text on given image.

                Args:
                    image (ndarray or PIL.Image.Image): image to draw on. The resulting
                        PIL image must have mode 'L', 'RGB' or 'RGBA'.
                    text (str): text to be drawn.
                    position (Tuple[int, int]): position where to be drawn.
                    color (Union[str, Tuple[int, int, int]]): text color. For ndarray
                        input the channel order of the color is reversed before drawing
                        (presumably because such arrays are treated as BGR -- confirm
                        with callers).
                    font (str): A filename or file-like object containing a TrueType font.
                        If the file is not found in this filename, the loader may also
                        search in other directories, such as the `fonts/` directory on
                        Windows or `/Library/Fonts/`, `/System/Library/Fonts/` and
                        `~/Library/Fonts/` on macOS. PIL's default font is used when None.
                    font_size (int): The requested font size in points. Only used when
                        `font` is given.

                Returns:
                    An ndarray built from the drawn image when `image` was an ndarray,
                    otherwise the PIL image that was passed in, with the text drawn on it.

                Raises:
                    TypeError: if `image` is neither an ndarray nor a PIL image.

                References:
                    torchvision.utils.draw_bounding_boxes
                """
                from_array = isinstance(image, np.ndarray)
                if from_array:
                    # PIL.Image.fromarray fails with uint16 arrays
                    # https://github.com/python-pillow/Pillow/issues/1514
                    if (image.ndim != 2) and (image.dtype == np.uint16):
                        image = (image / 256).astype(np.uint8)
                    pil_image = Image.fromarray(image)
                elif isinstance(image, PIL.Image.Image):
                    pil_image = image
                else:
                    raise TypeError('Unsupported image type!')
                assert pil_image.mode in ['L', 'RGB', 'RGBA']

                assert _is_legal_color(color)
                color = _normalize_color(color, pil_image.mode, from_array)

                font_object = (ImageFont.load_default() if font is None
                               else ImageFont.truetype(font, size=font_size))

                canvas = ImageDraw.Draw(pil_image)
                canvas.text((position[0], position[1]), text,
                            fill=color, font=font_object)

                return np.asarray(pil_image) if from_array else pil_image
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def draw_bounding_boxes(image, boxes, labels=None, colors=None,
                                    fill=False, width=1, font=None, font_size=15):
                """Draws bounding boxes on given image.

                Args:
                    image (ndarray or PIL.Image.Image): image to draw on. Drawing happens
                        on the result of converting it to mode 'RGB'.
                    boxes (ndarray): ndarray of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
                    labels (List[str]): List containing the labels of bounding boxes.
                    colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes or labels.
                        For ndarray input the channel order of each color is reversed
                        before drawing (presumably because such arrays are treated as
                        BGR -- confirm with callers).
                    fill (bool): If `True` fills the bounding box with specified color.
                    width (int): Width of bounding box.
                    font (str):  A filename or file-like object containing a TrueType font. If the file is not found in this
                        filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
                        or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
                    font_size (int): The requested font size in points.

                Returns:
                    An ndarray built from the drawn RGB image when `image` was an
                    ndarray, otherwise the drawn RGB PIL image.

                Raises:
                    TypeError: if `image` is neither an ndarray nor a PIL image.

                References:
                    torchvision.utils.draw_bounding_boxes
                """
                from_array = isinstance(image, np.ndarray)
                if from_array:
                    # PIL.Image.fromarray fails with uint16 arrays
                    # https://github.com/python-pillow/Pillow/issues/1514
                    if (image.dtype == np.uint16) and (image.ndim != 2):
                        image = (image / 256).astype(np.uint8)
                    pil_image = Image.fromarray(image)
                elif isinstance(image, PIL.Image.Image):
                    pil_image = image
                else:
                    raise TypeError('Unsupported image type!')
                pil_image = pil_image.convert('RGB')

                if font is None:
                    font_object = ImageFont.load_default()
                else:
                    font_object = ImageFont.truetype(font, size=font_size)

                if fill:
                    # The "RGBA" drawing mode lets the translucent fill blend with the image.
                    draw = ImageDraw.Draw(pil_image, "RGBA")
                else:
                    draw = ImageDraw.Draw(pil_image)

                for i, bbox in enumerate(boxes):
                    color = None if colors is None else colors[i]

                    assert _is_legal_color(color)
                    # After this call the color is None or a sequence of channels;
                    # strings have already been parsed by _normalize_color.
                    color = _normalize_color(color, pil_image.mode, from_array)
                    if isinstance(color, list):
                        # A list passes through unchanged when no channel swap is
                        # done; use a tuple so it can be concatenated with the alpha
                        # value below and is handled like every other color.
                        color = tuple(color)

                    # the first argument of ImageDraw.rectangle:
                    # in old version only supports [(x0, y0), (x1, y1)]
                    # in new version supports either [(x0, y0), (x1, y1)] or [x0, y0, x1, y1]
                    corners = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
                    if fill:
                        if color is None:
                            fill_color = (255, 255, 255, 100)
                        else:
                            # Keep only the first three channels (a parsed color
                            # string may carry its own alpha) and use a fixed
                            # alpha of 100 for the fill.
                            fill_color = tuple(color[:3]) + (100,)
                        draw.rectangle(corners, width=width, outline=color, fill=fill_color)
                    else:
                        draw.rectangle(corners, width=width, outline=color)

                    if labels is not None:
                        # Offset the label so it sits inside the box outline.
                        margin = width + 1
                        draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=font_object)

                if from_array:
                    return np.asarray(pil_image)
                return pil_image
         | 
| 147 | 
            +
                
         | 
| 148 | 
            +
                
         | 
    	
        khandy/feature_utils.py
    ADDED
    
    | @@ -0,0 +1,62 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from collections import OrderedDict
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import khandy
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def convert_feature_dict_to_array(feature_dict):
                """Stack the per-key feature arrays of a dict into one 2-d array.

                Args:
                    feature_dict: mapping from key to an array of features. The
                        last-axis size and dtype of the first entry define the
                        output's column count and dtype.

                Returns:
                    ``(key_list, feature_array)`` where ``feature_array`` holds all
                    features in iteration order and ``key_list[i]`` is the key that
                    row ``i`` came from.
                """
                sample = khandy.get_dict_first_item(feature_dict)[1]
                total_rows = sum(len(feats) for feats in feature_dict.values())
                feature_array = np.empty((total_rows, sample.shape[-1]), sample.dtype)

                key_list = []
                offset = 0
                for key, feats in feature_dict.items():
                    count = len(feats)
                    feature_array[offset: offset + count] = feats
                    key_list.extend([key] * count)
                    offset += count
                return key_list, feature_array
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def convert_feature_array_to_dict(key_list, feature_array):
                """Group the rows of a feature array by their keys.

                Args:
                    key_list: one key per row of ``feature_array``.
                    feature_array: array (or sequence) of features, same length as
                        ``key_list``.

                Returns:
                    OrderedDict, in order of first key occurrence, mapping each key
                    to the ``np.vstack`` of the rows that carry it.
                """
                assert len(key_list) == len(feature_array)
                rows_by_key = OrderedDict()
                for key, feat in zip(key_list, feature_array):
                    if key not in rows_by_key:
                        rows_by_key[key] = []
                    rows_by_key[key].append(feat)
                return OrderedDict(
                    (key, np.vstack(rows)) for key, rows in rows_by_key.items()
                )
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
            def pairwise_distances(x, y, squared=True):
                """Compute pairwise (squared) Euclidean distances.

                Uses the expansion ``|a - b|^2 = |a|^2 - 2 a.b + |b|^2``.

                Args:
                    x: 2-d ndarray, one vector per row.
                    y: 2-d ndarray with the same number of columns as ``x``.
                    squared: return squared distances when True, otherwise their
                        square roots.

                Returns:
                    ndarray of shape ``(len(x), len(y))`` whose entry ``[i, j]`` is
                    the distance between ``x[i]`` and ``y[j]``.

                References:
                    [2016 CVPR] Deep Metric Learning via Lifted Structured Feature Embedding
                    `euclidean_distances` from sklearn
                """
                assert isinstance(x, np.ndarray) and x.ndim == 2
                assert isinstance(y, np.ndarray) and y.ndim == 2
                assert x.shape[1] == y.shape[1]

                same_input = x is y
                # Row-wise squared norms: a column for x, a row for y.
                x_square = np.einsum('ij,ij->i', x, x)[:, np.newaxis]
                if same_input:
                    y_square = x_square.T
                else:
                    y_square = np.einsum('ij,ij->i', y, y)[np.newaxis, :]

                # use inplace operation to accelerate
                distances = np.dot(x, y.T)
                distances *= -2
                distances += x_square
                distances += y_square
                # result maybe less than 0 due to floating point rounding errors.
                np.maximum(distances, 0, out=distances)
                if same_input:
                    # Ensure that distances between vectors and themselves are set to 0.0.
                    # This may not be the case due to floating point rounding errors.
                    np.fill_diagonal(distances, 0.0)
                if not squared:
                    np.sqrt(distances, out=distances)
                return distances
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
                
         | 
    	
        khandy/file_io_utils.py
    ADDED
    
    | @@ -0,0 +1,87 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import numbers
         | 
| 4 | 
            +
            import pickle
         | 
| 5 | 
            +
            import warnings
         | 
| 6 | 
            +
            from collections import OrderedDict
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def load_list(filename, encoding='utf-8', start=0, stop=None):
                """Read lines ``[start, stop)`` of a text file into a list.

                Args:
                    filename: path of the text file to read.
                    encoding: text encoding passed to ``open``.
                    start: zero-based index of the first line to return.
                    stop: zero-based index one past the last line to return
                        (exclusive, like a slice bound). ``None`` reads to the end
                        of the file.

                Returns:
                    List of lines with the trailing newline character removed.
                """
                from itertools import islice

                assert isinstance(start, numbers.Integral) and start >= 0
                assert (stop is None) or (isinstance(stop, numbers.Integral) and stop > start)

                with open(filename, 'r', encoding=encoding) as f:
                    # islice gives half-open [start, stop) semantics; the previous
                    # hand-rolled loop also returned line `stop` (off by one).
                    return [line.rstrip('\n') for line in islice(f, start, stop)]
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def save_list(filename, list_obj, encoding='utf-8', append_break=True):
                """Write every item of ``list_obj`` to a text file as ``str(item)``.

                Args:
                    filename: destination path; the file is opened in ``'w'`` mode.
                    list_obj: iterable of items to write.
                    encoding: text encoding passed to ``open``.
                    append_break: when True, a newline follows each item; when
                        False, items are written back to back.
                """
                terminator = '\n' if append_break else ''
                with open(filename, 'w', encoding=encoding) as f:
                    f.writelines(str(item) + terminator for item in list_obj)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def load_json(filename, encoding='utf-8'):
                """Parse a JSON file, keeping object keys in file order.

                Args:
                    filename: path of the JSON file.
                    encoding: text encoding passed to ``open``.

                Returns:
                    The decoded document; JSON objects become ``OrderedDict``.
                """
                with open(filename, 'r', encoding=encoding) as f:
                    text = f.read()
                return json.loads(text, object_pairs_hook=OrderedDict)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def save_json(filename, data, encoding='utf-8', indent=4, cls=None, sort_keys=False):
                """Serialize ``data`` to a JSON file.

                Args:
                    filename: destination path as a string; ``'.json'`` is appended
                        when the name does not already end with it.
                    data: object to serialize.
                    encoding: text encoding passed to ``open``.
                    indent: indentation forwarded to ``json.dump``.
                    cls: optional encoder class forwarded to ``json.dump``.
                    sort_keys: forwarded to ``json.dump``.
                """
                target = filename if filename.endswith('.json') else filename + '.json'
                dump_options = dict(
                    indent=indent,
                    separators=(',', ': '),
                    ensure_ascii=False,  # write non-ASCII text as-is, not \u escapes
                    cls=cls,
                    sort_keys=sort_keys,
                )
                with open(target, 'w', encoding=encoding) as f:
                    json.dump(data, f, **dump_options)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def load_bytes(filename, use_base64: bool = False) -> bytes:
                """Read a whole file in binary mode.

                Args:
                    filename: path of the file to read.
                    use_base64: when True, the content is returned base64-encoded.

                Returns:
                    The raw file content, or its base64 encoding.

                References:
                    pathlib.Path.read_bytes
                """
                with open(filename, 'rb') as f:
                    raw = f.read()
                return base64.b64encode(raw) if use_base64 else raw
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def save_bytes(filename, data: bytes, use_base64: bool = False) -> int:
                """Write bytes to a file in binary mode.

                Args:
                    filename: destination path; the file is opened in ``'wb'`` mode.
                    data: content to write.
                    use_base64: when True, ``data`` is base64-decoded before it is
                        written.

                Returns:
                    The value returned by the file's ``write`` call.

                References:
                    pathlib.Path.write_bytes
                """
                payload = base64.b64decode(data) if use_base64 else data
                with open(filename, 'wb') as f:
                    return f.write(payload)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def load_as_base64(filename) -> bytes:
                """Read a file and return its content base64-encoded.

                Emits a warning that points callers to ``load_bytes``, which this
                function delegates to.
                """
                warnings.warn('khandy.load_as_base64 will be deprecated, use khandy.load_bytes instead!')
                encoded = load_bytes(filename, use_base64=True)
                return encoded
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def load_object(filename):
                """Load a pickled object from ``filename``.

                NOTE(review): unpickling can execute arbitrary code; only use this
                on files from a trusted source.
                """
                with open(filename, 'rb') as pickle_file:
                    loaded = pickle.load(pickle_file)
                return loaded
         | 
| 82 | 
            +
                    
         | 
| 83 | 
            +
                    
         | 
| 84 | 
            +
            def save_object(filename, obj):
                """Pickle ``obj`` into ``filename``, overwriting any existing file."""
                with open(filename, 'wb') as pickle_file:
                    pickle.dump(obj, pickle_file)
         | 
| 87 | 
            +
                    
         | 
    	
        khandy/fs_utils.py
    ADDED
    
    | @@ -0,0 +1,375 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import shutil
         | 
| 4 | 
            +
            import warnings
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def get_path_stem(path):
                """Return the final path component without its last extension.

                References:
                    `std::filesystem::path::stem` since C++17
                """
                stem, _ = os.path.splitext(os.path.basename(path))
                return stem
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def replace_path_stem(path, new_stem):
                """Return ``path`` with its stem replaced; directory and extension stay.

                Args:
                    path: the original path.
                    new_stem: either the replacement stem as a string, or a callable
                        that receives the old stem and returns the new one.

                Raises:
                    TypeError: if ``new_stem`` is neither a string nor callable.
                """
                parent, filename = os.path.split(path)
                old_stem, suffix = os.path.splitext(filename)

                if isinstance(new_stem, str):
                    stem = new_stem
                elif hasattr(new_stem, '__call__'):
                    stem = new_stem(old_stem)
                else:
                    raise TypeError('Unsupported Type!')
                return os.path.join(parent, stem + suffix)
         | 
| 24 | 
            +
                    
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            def get_path_extension(path):
                """Return the last extension of the final path component, dot included.

                An empty string is returned when there is no extension.

                References:
                    `std::filesystem::path::extension` since C++17

                Notes:
                    Not fully consistent with `std::filesystem::path::extension`
                """
                _, extension = os.path.splitext(os.path.basename(path))
                return extension
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def replace_path_extension(path, new_extension=None):
                """Swap the extension of ``path`` for ``new_extension``, or drop it.

                Any existing extension is stripped first. With the default ``None``
                (or an empty string) nothing is appended. Otherwise ``new_extension``
                is appended, and a dot is inserted if it does not start with one.

                References:
                    `std::filesystem::path::replace_extension` since C++17
                """
                root, _ = os.path.splitext(path)
                if new_extension is None or new_extension == '':
                    return root
                if new_extension.startswith('.'):
                    return root + new_extension
                return root + '.' + new_extension
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def normalize_extension(extension):
                """Return ``extension`` lower-cased and with a leading dot."""
                lowered = extension.lower()
                return lowered if lowered.startswith('.') else '.' + lowered
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def is_path_in_extensions(path, extensions):
                """Tell whether the extension of ``path`` is one of ``extensions``.

                Args:
                    path: path whose extension is examined.
                    extensions: a single extension string or an iterable of them.
                        Each is normalized with ``normalize_extension`` and the
                        path's extension is lower-cased before comparing.

                Returns:
                    True if the path's extension matches any given extension.
                """
                candidates = [extensions] if isinstance(extensions, str) else extensions
                allowed = [normalize_extension(candidate) for candidate in candidates]
                return get_path_extension(path).lower() in allowed
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def normalize_path(path, norm_case=True):
                """Expand, absolutize and optionally case-normalize ``path``.

                Steps, in order:
                    1. ``os.path.expanduser`` replaces a leading ``~`` or ``~user``
                       with that user's home directory.
                    2. ``os.path.abspath`` makes the path absolute and normalized.
                    3. If ``norm_case`` is set, ``os.path.normcase`` is applied. On
                       Windows it lower-cases the path and turns forward slashes
                       into backslashes; elsewhere the path is left unchanged.

                References:
                    https://en.cppreference.com/w/cpp/filesystem/canonical
                """
                absolute = os.path.abspath(os.path.expanduser(path))
                return os.path.normcase(absolute) if norm_case else absolute
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            def makedirs(name, mode=0o755):
                """Create directory ``name`` and any missing parents.

                An existing directory is not an error. An empty ``name`` is a no-op.
                ``~`` is expanded before creating. A warning about the planned
                deprecation is emitted on every call.

                Args:
                    name: directory path to create.
                    mode: permission bits forwarded to ``os.makedirs``.

                References:
                    mmcv.mkdir_or_exist
                """
                warnings.warn('`makedirs` will be deprecated!')
                if name != '':
                    os.makedirs(os.path.expanduser(name), mode=mode, exist_ok=True)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
            def listdirs(paths, path_sep=None, full_path=True):
                """List the entries of one or more directories (an extended ``os.listdir``).

                Args:
                    paths: a list/tuple of directory paths, or a single string of
                        paths joined by ``path_sep``.
                    path_sep: separator used to split a string ``paths``; defaults
                        to ``os.path.pathsep``.
                    full_path: when True, each entry is joined with its
                        (user-expanded) directory; otherwise bare names are returned.

                Returns:
                    Entries of all directories concatenated in the order given.
                """
                warnings.warn('`listdirs` will be deprecated!')
                assert isinstance(paths, (str, tuple, list))
                if isinstance(paths, str):
                    separator = path_sep or os.path.pathsep
                    paths = paths.split(separator)

                collected = []
                for directory in paths:
                    expanded = os.path.expanduser(directory)
                    for entry in os.listdir(expanded):
                        collected.append(os.path.join(expanded, entry) if full_path else entry)
                return collected
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            def get_all_filenames(path, extensions=None, is_valid_file=None):
                """Recursively collect the files under ``path`` that pass a filter.

                Args:
                    path: root directory; ``~`` is expanded. Symlinked directories
                        are followed.
                    extensions: extension or extensions to keep, checked with
                        ``is_path_in_extensions``. Mutually exclusive with
                        ``is_valid_file``.
                    is_valid_file: callable taking the full file path and returning
                        whether to keep it.

                Returns:
                    Full paths, ordered by directory and then by file name.

                Raises:
                    ValueError: if both ``extensions`` and ``is_valid_file`` are given.
                """
                warnings.warn('`get_all_filenames` will be deprecated, use `list_files_in_dir` with `recursive=True` instead!')
                if (extensions is not None) and (is_valid_file is not None):
                    raise ValueError("Both extensions and is_valid_file cannot "
                                     "be not None at the same time")

                # Settle on a single predicate before walking the tree.
                if is_valid_file is not None:
                    keep = is_valid_file
                elif extensions is not None:
                    def keep(filename):
                        return is_path_in_extensions(filename, extensions)
                else:
                    def keep(filename):
                        return True

                matched = []
                walk_root = os.path.expanduser(path)
                for root, _, filenames in sorted(os.walk(walk_root, followlinks=True)):
                    candidates = (os.path.join(root, name) for name in sorted(filenames))
                    matched.extend(fullname for fullname in candidates if keep(fullname))
                return matched
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            def get_top_level_dirs(path, full_path=True):
                """List the immediate subdirectories of ``path``.

                Args:
                    path: directory to inspect; ``None`` means the current working
                        directory. ``~`` is expanded.
                    full_path: when True, return paths joined with the directory;
                        otherwise return bare names.

                Returns:
                    Subdirectories in the order ``os.listdir`` yields them.
                """
                warnings.warn('`get_top_level_dirs` will be deprecated, use `list_dirs_in_dir` instead!')
                base = os.path.expanduser(os.getcwd() if path is None else path)

                found = []
                for entry in os.listdir(base):
                    entry_path = os.path.join(base, entry)
                    if os.path.isdir(entry_path):
                        found.append(entry_path if full_path else entry)
                return found
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            def get_top_level_files(path, full_path=True):
                """List the regular files directly inside ``path``.

                Args:
                    path: directory to inspect; ``None`` means the current working
                        directory. ``~`` is expanded.
                    full_path: when True, return paths joined with the directory;
                        otherwise return bare names.

                Returns:
                    Files in the order ``os.listdir`` yields them.
                """
                warnings.warn('`get_top_level_files` will be deprecated, use `list_files_in_dir` instead!')
                base = os.path.expanduser(os.getcwd() if path is None else path)

                found = []
                for entry in os.listdir(base):
                    entry_path = os.path.join(base, entry)
                    if os.path.isfile(entry_path):
                        found.append(entry_path if full_path else entry)
                return found
         | 
| 171 | 
            +
                            
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            def list_items_in_dir(path=None, recursive=False, full_path=True):
                """List all entries (directories and files) in a directory.

                Args:
                    path: directory to list; ``None`` means the current working
                        directory. ``~`` is expanded.
                    recursive: when True, descend into subdirectories, following
                        symlinked directories.
                    full_path: when True, entries are joined with ``path``. When
                        False, bare names are returned in the non-recursive case
                        and paths relative to ``path`` in the recursive case.

                Returns:
                    Sorted entries. In recursive mode they are grouped per visited
                    directory, subdirectories first and then files.
                """
                if path is None:
                    path = os.getcwd()
                path_ex = os.path.expanduser(path)

                if not recursive:
                    names = sorted(os.listdir(path_ex))
                    if full_path:
                        return [os.path.join(path_ex, name) for name in names]
                    return names

                all_names = []
                for root, dirnames, filenames in sorted(os.walk(path_ex, followlinks=True)):
                    if full_path:
                        prefix = root
                    else:
                        # Honour full_path=False here too (it used to be ignored).
                        # An empty prefix keeps top-level names free of './'.
                        relative = os.path.relpath(root, path_ex)
                        prefix = '' if relative == os.curdir else relative
                    all_names += [os.path.join(prefix, name) for name in sorted(dirnames)]
                    all_names += [os.path.join(prefix, name) for name in sorted(filenames)]
                return all_names
         | 
| 192 | 
            +
             | 
| 193 | 
            +
             | 
| 194 | 
            +
            def list_dirs_in_dir(path=None, recursive=False, full_path=True):
                """List all directories in a directory.

                Args:
                    path: directory to list; ``None`` means the current working
                        directory. ``~`` is expanded.
                    recursive: when True, descend into subdirectories, following
                        symlinked directories.
                    full_path: when True, entries are joined with ``path``. When
                        False, bare names are returned in the non-recursive case
                        and paths relative to ``path`` in the recursive case.

                Returns:
                    Sorted directory entries; in recursive mode they are grouped
                    per visited directory.
                """
                if path is None:
                    path = os.getcwd()
                path_ex = os.path.expanduser(path)

                if not recursive:
                    dirnames = [name for name in sorted(os.listdir(path_ex))
                                if os.path.isdir(os.path.join(path_ex, name))]
                    if full_path:
                        return [os.path.join(path_ex, name) for name in dirnames]
                    return dirnames

                all_names = []
                for root, dirnames, _ in sorted(os.walk(path_ex, followlinks=True)):
                    if full_path:
                        prefix = root
                    else:
                        # Honour full_path=False here too (it used to be ignored).
                        # An empty prefix keeps top-level names free of './'.
                        relative = os.path.relpath(root, path_ex)
                        prefix = '' if relative == os.curdir else relative
                    all_names += [os.path.join(prefix, name) for name in sorted(dirnames)]
                return all_names
         | 
| 214 | 
            +
             | 
| 215 | 
            +
             | 
| 216 | 
            +
            def list_files_in_dir(path=None, recursive=False, full_path=True):
                """List all files in a directory.

                Args:
                    path: directory to list; ``None`` means the current working
                        directory. ``~`` is expanded.
                    recursive: when True, descend into subdirectories, following
                        symlinked directories.
                    full_path: when True, entries are joined with ``path``. When
                        False, bare names are returned in the non-recursive case
                        and paths relative to ``path`` in the recursive case.

                Returns:
                    Sorted file entries; in recursive mode they are grouped per
                    visited directory.
                """
                if path is None:
                    path = os.getcwd()
                path_ex = os.path.expanduser(path)

                if not recursive:
                    filenames = [name for name in sorted(os.listdir(path_ex))
                                 if os.path.isfile(os.path.join(path_ex, name))]
                    if full_path:
                        return [os.path.join(path_ex, name) for name in filenames]
                    return filenames

                all_names = []
                for root, _, filenames in sorted(os.walk(path_ex, followlinks=True)):
                    if full_path:
                        prefix = root
                    else:
                        # Honour full_path=False here too (it used to be ignored).
                        # An empty prefix keeps top-level names free of './'.
                        relative = os.path.relpath(root, path_ex)
                        prefix = '' if relative == os.curdir else relative
                    all_names += [os.path.join(prefix, name) for name in sorted(filenames)]
                return all_names
         | 
| 236 | 
            +
                    
         | 
| 237 | 
            +
             | 
| 238 | 
            +
            def get_folder_size(dirname):
                """Return the combined size in bytes of every file under ``dirname``.

                The tree is traversed with ``os.walk`` and each file is measured
                with ``os.path.getsize``.

                Raises:
                    ValueError: if ``dirname`` does not exist.
                """
                if not os.path.exists(dirname):
                    raise ValueError("Incorrect path: {}".format(dirname))
                return sum(
                    os.path.getsize(os.path.join(root, name))
                    for root, _, filenames in os.walk(dirname)
                    for name in filenames
                )
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                
         | 
| 248 | 
            +
            def escape_filename(filename, new_char='_'):
                """Replace characters that are unsafe in file names.

                The replaced characters are backslash, slash, ``* ? : " < > |`` and
                the control characters 0x00-0x1F.

                Args:
                    filename: the name to sanitize.
                    new_char: replacement text. It is passed to the regex engine as
                        a replacement template, so backslashes in it are interpreted.

                Returns:
                    The sanitized name.
                """
                assert isinstance(new_char, str)
                control_chars = ''.join(chr(code) for code in range(0x00, 0x20))
                invalid_chars = re.compile(r'[\\/*?:"<>|{}]'.format(control_chars))
                return invalid_chars.sub(new_char, filename)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
             | 
| 255 | 
            +
            def replace_invalid_filename_char(filename, new_char='_'):
                """Deprecated alias of ``escape_filename``.

                Emits a warning and forwards both arguments unchanged.
                """
                warnings.warn('`replace_invalid_filename_char` will be deprecated, use `escape_filename` instead!')
                escaped = escape_filename(filename, new_char)
                return escaped
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
            def copy_file(src, dst_dir, action_if_exist='rename'):
                """Copy ``src`` into the directory ``dst_dir``.

                The destination keeps the base name of ``src``. ``dst_dir`` is created
                when it is missing.

                Args:
                    src: source file path
                    dst_dir: dest dir
                    action_if_exist: what to do when the dest file already exists
                        (strings are matched case-insensitively)
                        None: same as shutil.copy
                        ignore: don't copy, warn, and return the existing dest path
                        rename: copy to ``"<stem> (<n>)<ext>"`` using the first
                            ``n >= 2`` that does not exist yet

                Returns:
                    dest filename

                Raises:
                    ValueError: if ``action_if_exist`` is not one of the values above.
                """
                if action_if_exist is None:
                    action = None
                elif isinstance(action_if_exist, str):
                    action = action_if_exist.lower()
                else:
                    action = action_if_exist
                # Validate before touching the file system so that a bad argument
                # never creates ``dst_dir`` as a side effect.
                if action not in (None, 'ignore', 'rename'):
                    raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))

                basename = os.path.basename(src)
                dst = os.path.join(dst_dir, basename)

                if action == 'ignore' and os.path.exists(dst):
                    warnings.warn(f'{dst} already exists, do not copy!')
                    return dst

                if action == 'rename':
                    stem, extension = os.path.splitext(basename)
                    suffix = 2
                    while os.path.exists(dst):
                        dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
                        suffix += 1

                os.makedirs(dst_dir, exist_ok=True)
                shutil.copy(src, dst)
                return dst
         | 
| 296 | 
            +
                
         | 
| 297 | 
            +
                
         | 
| 298 | 
            +
            def move_file(src, dst_dir, action_if_exist='rename'):
                """Move ``src`` into the directory ``dst_dir``.

                The destination keeps the base name of ``src``. ``dst_dir`` is created
                when it is missing.

                Args:
                    src: source file path
                    dst_dir: dest dir
                    action_if_exist: what to do when the dest file already exists
                        (strings are matched case-insensitively)
                        None: same as shutil.move
                        ignore: don't move, warn, and return the existing dest path
                        rename: move to ``"<stem> (<n>)<ext>"`` using the first
                            ``n >= 2`` that does not exist yet

                Returns:
                    dest filename

                Raises:
                    ValueError: if ``action_if_exist`` is not one of the values above.
                """
                if action_if_exist is None:
                    action = None
                elif isinstance(action_if_exist, str):
                    action = action_if_exist.lower()
                else:
                    action = action_if_exist
                # Validate before touching the file system so that a bad argument
                # never creates ``dst_dir`` as a side effect.
                if action not in (None, 'ignore', 'rename'):
                    raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))

                basename = os.path.basename(src)
                dst = os.path.join(dst_dir, basename)

                if action == 'ignore' and os.path.exists(dst):
                    warnings.warn(f'{dst} already exists, do not move!')
                    return dst

                if action == 'rename':
                    stem, extension = os.path.splitext(basename)
                    suffix = 2
                    while os.path.exists(dst):
                        dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
                        suffix += 1

                os.makedirs(dst_dir, exist_ok=True)
                shutil.move(src, dst)
                return dst
         | 
| 334 | 
            +
                
         | 
| 335 | 
            +
                
         | 
| 336 | 
            +
            def rename_file(src, dst, action_if_exist='rename'):
                """Rename ``src`` to ``dst``.

                The parent directory of ``dst`` is created when it is missing. When
                ``dst`` equals ``src`` nothing is done and ``dst`` is returned as is.

                Args:
                    src: source file path
                    dst: dest file path
                    action_if_exist: what to do when the dest file already exists
                        (strings are matched case-insensitively)
                        None: same as os.rename
                        ignore: don't rename, warn, and return the existing dest path
                        rename: rename to ``"<stem> (<n>)<ext>"`` inside the
                            absolute parent directory of ``dst``, using the first
                            ``n >= 2`` that does not exist yet

                Returns:
                    dest filename

                Raises:
                    ValueError: if ``action_if_exist`` is not one of the values above.
                """
                if dst == src:
                    return dst

                if action_if_exist is None:
                    action = None
                elif isinstance(action_if_exist, str):
                    action = action_if_exist.lower()
                else:
                    action = action_if_exist
                # Validate before touching the file system so that a bad argument
                # never creates the destination directory as a side effect.
                if action not in (None, 'ignore', 'rename'):
                    raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))

                dst_dir = os.path.dirname(os.path.abspath(dst))

                if action == 'ignore' and os.path.exists(dst):
                    warnings.warn(f'{dst} already exists, do not rename!')
                    return dst

                if action == 'rename':
                    stem, extension = os.path.splitext(os.path.basename(dst))
                    suffix = 2
                    while os.path.exists(dst):
                        dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
                        suffix += 1

                os.makedirs(dst_dir, exist_ok=True)
                os.rename(src, dst)
                return dst
         | 
| 374 | 
            +
                
         | 
| 375 | 
            +
                
         | 
    	
        khandy/hash_utils.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import hashlib
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def calc_hash(content, hash_object=None):
                """Return the hexadecimal digest of ``content``.

                Args:
                    content: bytes-like data fed to the hash object's ``update``.
                    hash_object: an existing hash object, an algorithm name accepted
                        by ``hashlib.new``, or a falsy value to use MD5. A hash object
                        passed in is updated in place.

                Returns:
                    The digest as a hex string.
                """
                if not hash_object:
                    hasher = hashlib.md5()
                elif isinstance(hash_object, str):
                    hasher = hashlib.new(hash_object)
                else:
                    hasher = hash_object
                hasher.update(content)
                return hasher.hexdigest()
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def calc_file_hash(filename, hash_object=None, chunk_size=1024 * 1024):
                """Return the hexadecimal digest of a file's contents.

                The file is opened in binary mode and consumed ``chunk_size`` bytes at
                a time, so it never has to be held in memory as a whole.

                Args:
                    filename: path of the file to hash.
                    hash_object: an existing hash object, an algorithm name accepted
                        by ``hashlib.new``, or a falsy value to use MD5. A hash object
                        passed in is updated in place.
                    chunk_size: number of bytes requested per read.

                Returns:
                    The digest as a hex string.
                """
                if not hash_object:
                    hasher = hashlib.md5()
                elif isinstance(hash_object, str):
                    hasher = hashlib.new(hash_object)
                else:
                    hasher = hash_object

                with open(filename, "rb") as f:
                    # An empty read marks the end of the file.
                    for chunk in iter(lambda: f.read(chunk_size), b''):
                        hasher.update(chunk)
                return hasher.hexdigest()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                
         | 
    	
        khandy/image/__init__.py
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .align_and_crop import *
         | 
| 2 | 
            +
            from .crop_or_pad import *
         | 
| 3 | 
            +
            from .flip import *
         | 
| 4 | 
            +
            from .image_hash import *
         | 
| 5 | 
            +
            from .resize import *
         | 
| 6 | 
            +
            from .rotate import *
         | 
| 7 | 
            +
            from .translate import *
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .misc import *
         | 
| 10 | 
            +
             | 
    	
        khandy/image/align_and_crop.py
    ADDED
    
    | @@ -0,0 +1,60 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def get_similarity_transform(src_pts, dst_pts):
                """Estimate the similarity transform that maps src_pts onto dst_pts.

                The four unknowns (sc, ss, tx, ty) of
                    x' = sc * x - ss * y + tx
                    y' = ss * x + sc * y + ty
                are found as the least-squares solution over all point pairs.

                Args:
                    src_pts: Kx2 np.array
                        source points matrix, each row is a pair of coordinates (x, y)
                    dst_pts: Kx2 np.array
                        destination points matrix, each row is a pair of coordinates (x, y)

                Returns:
                    xform_matrix: 3x3 np.array
                        transform matrix from src_pts to dst_pts
                """
                src_pts = np.asarray(src_pts)
                dst_pts = np.asarray(dst_pts)
                assert src_pts.shape == dst_pts.shape
                assert (src_pts.ndim == 2) and (src_pts.shape[-1] == 2)

                num_points = src_pts.shape[0]
                ones = np.ones(num_points)
                zeros = np.zeros(num_points)
                xs, ys = src_pts[:, 0], src_pts[:, 1]

                # First K rows are the x' equations, last K rows the y' equations.
                coeffs = np.concatenate((
                    np.column_stack((xs, -ys, ones, zeros)),
                    np.column_stack((ys, xs, zeros, ones)),
                ), axis=0)
                targets = np.concatenate((dst_pts[:, 0], dst_pts[:, 1])).reshape((-1, 1))

                solution = np.linalg.lstsq(coeffs, targets, rcond=-1)[0]
                sc, ss, tx, ty = np.squeeze(solution)
                return np.array([
                    [sc, -ss, tx],
                    [ss,  sc, ty],
                    [ 0,   0,  1]
                ])
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
            def align_and_crop(image, landmarks, std_landmarks, align_size, 
                               border_value=0, return_transform_matrix=False):
                """Warp ``image`` so that ``landmarks`` line up with ``std_landmarks``.

                A similarity transform from ``landmarks`` to ``std_landmarks`` is
                estimated with ``get_similarity_transform`` and applied to both the
                image (via ``cv2.warpAffine``) and the landmarks.

                Args:
                    image: image handed to ``cv2.warpAffine``.
                    landmarks: Kx2 points (x, y) located in ``image``.
                    std_landmarks: Kx2 reference points the landmarks are mapped to.
                    align_size: output size, passed as ``dsize`` to ``cv2.warpAffine``.
                    border_value: passed as ``borderValue`` to ``cv2.warpAffine``.
                    return_transform_matrix: when True, also return the 3x3 matrix.

                Returns:
                    ``(dst_image, dst_landmarks)`` or, if requested,
                    ``(dst_image, dst_landmarks, xform_matrix)``.
                """
                landmarks = np.asarray(landmarks)
                std_landmarks = np.asarray(std_landmarks)
                xform_matrix = get_similarity_transform(landmarks, std_landmarks)
                affine_part = xform_matrix[:2,:]

                # Append a column of ones so the translation is applied by the dot product.
                homogeneous = np.pad(landmarks, ((0,0),(0,1)), mode='constant', constant_values=1)
                dst_landmarks = np.dot(homogeneous, affine_part.T)
                dst_image = cv2.warpAffine(image, affine_part, dsize=align_size, 
                                           borderValue=border_value)
                if not return_transform_matrix:
                    return dst_image, dst_landmarks
                return dst_image, dst_landmarks, xform_matrix
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    
         | 
    	
        khandy/image/crop_or_pad.py
    ADDED
    
    | @@ -0,0 +1,138 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numbers
         | 
| 2 | 
            +
            import warnings
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import khandy
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def crop(image, x_min, y_min, x_max, y_max, border_value=0):
                """Cut the rectangle [x_min, x_max] x [y_min, y_max] out of an image.

                Both bounds are inclusive, so the result is
                ``(y_max - y_min + 1)`` rows by ``(x_max - x_min + 1)`` columns. Parts
                of the rectangle that fall outside ``image`` are filled with
                ``border_value``.

                Args:
                    image: numpy image (checked with ``khandy.is_numpy_image``).
                    x_min, y_min, x_max, y_max: integral rectangle bounds.
                    border_value: a scalar used for every channel, or a tuple/list
                        with one entry per channel.

                Returns:
                    A new image with the dtype of ``image``.

                See Also:
                    translate_image

                References:
                    PIL.Image.crop
                    tf.image.resize_image_with_crop_or_pad
                """
                assert khandy.is_numpy_image(image)
                for bound in (x_min, y_min, x_max, y_max):
                    assert isinstance(bound, numbers.Integral)
                assert (x_min <= x_max) and (y_min <= y_max)

                src_height, src_width = image.shape[:2]
                out_width = x_max - x_min + 1
                out_height = y_max - y_min + 1
                num_channels = image.shape[2] if image.ndim != 2 else 1

                if not isinstance(border_value, (tuple, list)):
                    border_value = (border_value,) * num_channels
                else:
                    assert len(border_value) == num_channels, \
                        'Expected the num of elements in tuple equals the channels ' \
                        'of input image. Found {} vs {}'.format(
                            len(border_value), num_channels)
                canvas = khandy.create_solid_color_image(
                    out_width, out_height, border_value, dtype=image.dtype)

                # Part of the requested rectangle that lies inside the source image.
                left, right = max(x_min, 0), min(x_max + 1, src_width)
                top, bottom = max(y_min, 0), min(y_max + 1, src_height)

                if (left < right) and (top < bottom):
                    canvas[top - y_min: bottom - y_min, left - x_min: right - x_min, ...] = \
                        image[top: bottom, left: right, ...]
                return canvas
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
                """Deprecated alias of ``crop``.

                Emits a warning and forwards all arguments unchanged.

                Returns:
                    Whatever ``crop`` returns for the same arguments.
                """
                # stacklevel=2 attributes the warning to the calling code, which is
                # the code that actually has to migrate to the new name.
                warnings.warn('crop_or_pad will be deprecated, use crop instead!',
                              stacklevel=2)
                return crop(image, x_min, y_min, x_max, y_max, border_value)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def crop_coords(boxes, image_width, image_height):
                """Compute copy ranges for cropping several boxes out of one image.

                For each box (x_min, y_min, x_max, y_max; bounds inclusive) this gives
                the slice of the image that is inside the box (``src_*``) and where
                that slice lands inside the cropped patch (``dst_*``).

                Args:
                    boxes: array indexed as ``boxes[:, 0..3]``, one box per row.
                    image_width: width of the source image.
                    image_height: height of the source image.

                Returns:
                    Array stacked along axis 0 in the order dst_y_begin, dst_y_end,
                    dst_x_begin, dst_x_end, src_y_begin, src_y_end, src_x_begin,
                    src_x_end, dst_heights, dst_widths.

                References:
                    `mmcv.impad`
                    `pad` in https://github.com/kpzhang93/MTCNN_face_detection_alignment
                    `MtcnnDetector.pad` in https://github.com/AITTSMD/MTCNN-Tensorflow
                """
                def axis_ranges(lower, upper, limit):
                    # Clamp to [0, limit) for the source, then shift by the box
                    # origin to obtain the matching range inside the patch.
                    src_begin = np.maximum(lower, 0)
                    src_end = np.minimum(upper + 1, limit)
                    return src_begin, src_end, src_begin - lower, src_end - lower

                x_mins, y_mins = boxes[:, 0], boxes[:, 1]
                x_maxs, y_maxs = boxes[:, 2], boxes[:, 3]

                src_x_begin, src_x_end, dst_x_begin, dst_x_end = axis_ranges(
                    x_mins, x_maxs, image_width)
                src_y_begin, src_y_end, dst_y_begin, dst_y_end = axis_ranges(
                    y_mins, y_maxs, image_height)

                dst_heights = y_maxs - y_mins + 1
                dst_widths = x_maxs - x_mins + 1

                return np.stack([dst_y_begin, dst_y_end, dst_x_begin, dst_x_end, 
                                 src_y_begin, src_y_end, src_x_begin, src_x_end, 
                                 dst_heights, dst_widths], axis=0)
         | 
| 87 | 
            +
                
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            def crop_or_pad_coords(boxes, image_width, image_height):
                """Deprecated alias of ``crop_coords``.

                Emits a warning and forwards all arguments unchanged.

                Returns:
                    Whatever ``crop_coords`` returns for the same arguments.
                """
                # stacklevel=2 attributes the warning to the calling code, which is
                # the code that actually has to migrate to the new name.
                warnings.warn('crop_or_pad_coords will be deprecated, use crop_coords instead!',
                              stacklevel=2)
                return crop_coords(boxes, image_width, image_height)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def center_crop(image, dst_width, dst_height, strict=True):
                """Take a centered window of dst_width x dst_height from an image.

                Args:
                    image: numpy image (checked with ``khandy.is_numpy_image``).
                    dst_width: integral width of the window.
                    dst_height: integral height of the window.
                    strict:
                        when True, raise error if src size is less than dst size.
                        when False, a dimension that is already smaller than
                        requested is left unchanged; otherwise it is center cropped.

                Returns:
                    The slice of ``image`` covering the window.
                """
                assert khandy.is_numpy_image(image)
                assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
                src_height, src_width = image.shape[:2]
                if strict:
                    assert (src_height >= dst_height) and (src_width >= dst_width)

                # A negative offset means the source is smaller; start at 0 then.
                top = max((src_height - dst_height) // 2, 0)
                left = max((src_width - dst_width) // 2, 0)
                rows = slice(top, dst_height + top)
                cols = slice(left, dst_width + left)
                return image[rows, cols, ...]
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def center_pad(image, dst_width, dst_height, strict=True):
                """Pad an image on all sides so it is centered in the target size.

                Padding uses ``np.pad`` in 'constant' mode. When the amount to add is
                odd, the extra row/column goes to the bottom/right.

                Args:
                    image: numpy image (checked with ``khandy.is_numpy_image``).
                    dst_width: integral target width.
                    dst_height: integral target height.
                    strict:
                        when True, raise error if src size is greater than dst size.
                        when False, a dimension that is already larger than
                        requested is left unchanged; otherwise it is center padded.

                Returns:
                    The padded image.
                """
                assert khandy.is_numpy_image(image)
                assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)

                src_height, src_width = image.shape[:2]
                if strict:
                    assert (src_height <= dst_height) and (src_width <= dst_width)

                extra_rows = max(dst_height - src_height, 0)
                extra_cols = max(dst_width - src_width, 0)
                top = extra_rows // 2
                left = extra_cols // 2

                pad_width = [(top, extra_rows - top), (left, extra_cols - left)]
                if image.ndim != 2:
                    # Leave the channel axis untouched.
                    pad_width.append((0, 0))
                return np.pad(image, tuple(pad_width), 'constant')
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                
         | 
    	
        khandy/image/flip.py
    ADDED
    
    | @@ -0,0 +1,72 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import khandy
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def flip_image(image, direction='h', copy=True):
                """Mirror an image horizontally, vertically, or both.

                Args:
                    image: numpy image (checked with ``khandy.is_numpy_image``).
                    direction: 'x', 'h' or 'horizontal' for a left-right flip;
                        'y', 'v' or 'vertical' for an up-down flip;
                        'o', 'b' or 'both' for both flips.
                    copy: when True the flips are applied to a copy of ``image``.

                Returns:
                    The flipped image.

                References:
                    np.flipud, np.fliplr, np.flip
                    cv2.flip
                    tf.image.flip_up_down
                    tf.image.flip_left_right
                """
                horizontal = ('x', 'h', 'horizontal')
                vertical = ('y', 'v', 'vertical')
                both = ('o', 'b', 'both')

                assert khandy.is_numpy_image(image)
                assert direction in horizontal + vertical + both

                flipped = image.copy() if copy else image
                if direction in both + horizontal:
                    flipped = np.fliplr(flipped)
                if direction in both + vertical:
                    flipped = np.flipud(flipped)
                return flipped
         | 
| 24 | 
            +
                
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
            def transpose_image(image, copy=True):
                """Swap the first two axes (rows and columns) of an image.

                A 2-D array is transposed with axes (1, 0); any other accepted
                array uses axes (1, 0, 2), leaving the last axis in place.

                Args:
                    image: ndarray accepted by ``khandy.is_numpy_image``.
                    copy: when True, the input is copied before transposing.

                References:
                    np.transpose
                    cv2.transpose
                    tf.image.transpose
                """
                assert khandy.is_numpy_image(image)
                source = image.copy() if copy else image
                axes_order = (1, 0) if source.ndim == 2 else (1, 0, 2)
                return np.transpose(source, axes_order)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                
         | 
| 45 | 
            +
            def rot90_image(image, n=1, copy=True):
                """Rotate image counter-clockwise by 90 degrees, ``n`` times.

                ``n`` is reduced modulo 4, so negative values and values above 3
                are accepted.

                Args:
                    image: ndarray accepted by ``khandy.is_numpy_image``.
                    n: number of quarter turns.
                    copy: when True, the input is copied before rotating.

                References:
                    np.rot90
                    cv2.rotate
                    tf.image.rot90
                """
                assert khandy.is_numpy_image(image)
                rotated = image.copy() if copy else image
                # Axis order that swaps rows and columns but keeps channels last.
                swap_rows_cols = (1, 0) if rotated.ndim == 2 else (1, 0, 2)

                quarter_turns = n % 4
                if quarter_turns == 0:
                    return rotated[:]
                if quarter_turns == 1:
                    return np.flipud(np.transpose(rotated, swap_rows_cols))
                if quarter_turns == 2:
                    return np.fliplr(np.flipud(rotated))
                return np.fliplr(np.transpose(rotated, swap_rows_cols))
         | 
    	
        khandy/image/image_hash.py
    ADDED
    
    | @@ -0,0 +1,69 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import khandy
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def _convert_bool_matrix_to_int(bool_mat):
                """Pack the flattened entries of ``bool_mat`` into one integer.

                Entries are read in flattened order; the first entry ends up as
                the most significant bit of the result.
                """
                packed = 0
                for bit in bool_mat.flatten():
                    packed = (packed << 1) | int(bit)
                return packed
         | 
| 12 | 
            +
                
         | 
| 13 | 
            +
                
         | 
| 14 | 
            +
            def calc_image_ahash(image):
                """Average Hashing.

                A 3-D input is converted to gray with ``cv2.COLOR_BGR2GRAY``, the
                image is resized to 8x8, and each pixel is compared with the mean
                of the resized image. The resulting bits are packed into an
                integer and returned as a hex string zero-padded to 16 digits.

                References:
                    http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
                """
                assert khandy.is_numpy_image(image)
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
                thumbnail = cv2.resize(gray, (8, 8))

                above_mean = thumbnail >= np.mean(thumbnail)
                hash_val = _convert_bool_matrix_to_int(above_mean)
                return f'{hash_val:016x}'
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
            def calc_image_dhash(image):
                """Difference Hashing.

                A 3-D input is converted to gray with ``cv2.COLOR_BGR2GRAY``, the
                image is resized with ``cv2.resize(image, (9, 8))``, and every
                pixel is compared with its right-hand neighbour. The resulting
                bits are packed into an integer and returned as a hex string
                zero-padded to 16 digits.

                References:
                    http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
                """
                assert khandy.is_numpy_image(image)
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
                thumbnail = cv2.resize(gray, (9, 8))

                left_columns = thumbnail[:, :-1]
                right_columns = thumbnail[:, 1:]
                hash_val = _convert_bool_matrix_to_int(left_columns >= right_columns)
                return f'{hash_val:016x}'
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
                
         | 
| 47 | 
            +
            def calc_image_phash(image):
                """Perceptual Hashing.

                A 3-D input is converted to gray with ``cv2.COLOR_BGR2GRAY``, the
                image is resized to 32x32 and transformed with ``cv2.dct``. The
                top-left 8x8 block of coefficients is thresholded at its median,
                and the bits are returned as a hex string zero-padded to 16 digits.

                References:
                    http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
                """
                assert khandy.is_numpy_image(image)
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
                thumbnail = cv2.resize(gray, (32, 32))

                coefficients = cv2.dct(thumbnail.astype(np.float32))
                low_freq = coefficients[:8, :8]

                # The threshold is the median of the retained coefficients; the
                # alternative (mean of the coefficients excluding the DC term)
                # is deliberately not used here.
                threshold = np.median(low_freq)

                hash_val = _convert_bool_matrix_to_int(low_freq >= threshold)
                return f'{hash_val:016x}'
         | 
| 69 | 
            +
                
         | 
    	
        khandy/image/misc.py
    ADDED
    
    | @@ -0,0 +1,329 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import imghdr
         | 
| 3 | 
            +
            import numbers
         | 
| 4 | 
            +
            import warnings
         | 
| 5 | 
            +
            from io import BytesIO
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import cv2
         | 
| 8 | 
            +
            import khandy
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from PIL import Image
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def imread(file_or_buffer, flags=-1):
                """Improvement on cv2.imread, make it support filename including chinese character.

                ``bytes`` input is decoded directly; anything else (file, str or
                Path) is read with ``np.fromfile`` first. The raw bytes are then
                decoded with ``cv2.imdecode`` using ``flags``.

                Returns:
                    The value returned by ``cv2.imdecode``, or None when any
                    exception is raised (the exception is printed).
                """
                try:
                    if isinstance(file_or_buffer, bytes):
                        raw = np.frombuffer(file_or_buffer, dtype=np.uint8)
                    else:
                        # support type: file or str or Path
                        raw = np.fromfile(file_or_buffer, dtype=np.uint8)
                    return cv2.imdecode(raw, flags)
                except Exception as e:
                    print(e)
                    return None
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def imread_cv(file_or_buffer, flags=-1):
                """Alias of ``imread`` that emits a warning before delegating."""
                message = 'khandy.imread_cv will be deprecated, use khandy.imread instead!'
                warnings.warn(message)
                return imread(file_or_buffer, flags)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def imwrite(filename, image, params=None):
                """Improvement on cv2.imwrite, make it support filename including chinese character.

                The image is encoded in memory with ``cv2.imencode`` using the
                extension of ``filename``, then the encoded buffer is written to
                ``filename`` with ``ndarray.tofile``.
                """
                extension = os.path.splitext(filename)[-1]
                encoded = cv2.imencode(extension, image, params)[1]
                encoded.tofile(filename)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def imwrite_cv(filename, image, params=None):
                """Alias of ``imwrite`` that emits a warning before delegating."""
                message = 'khandy.imwrite_cv will be deprecated, use khandy.imwrite instead!'
                warnings.warn(message)
                return imwrite(filename, image, params)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def imread_pil(file_or_buffer, to_mode=None):
                """Improvement on Image.open to avoid ResourceWarning.

                Args:
                    file_or_buffer: raw image ``bytes``, an object with a ``read``
                        attribute, or a path that can be passed to ``open``.
                    to_mode: optional PIL mode; when not None the image is
                        converted with ``Image.convert(to_mode)``.

                Returns:
                    The PIL image, or None when any exception is raised (the
                    exception is printed).
                """
                try:
                    if isinstance(file_or_buffer, bytes):
                        file_or_buffer = BytesIO(file_or_buffer)

                    if hasattr(file_or_buffer, 'read'):
                        image = Image.open(file_or_buffer)
                        if to_mode is not None:
                            image = image.convert(to_mode)
                    else:
                        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
                        with open(file_or_buffer, 'rb') as f:
                            image = Image.open(f)
                            # If convert outside with statement, will raise "seek of closed file" as
                            # https://github.com/microsoft/Swin-Transformer/issues/66
                            if to_mode is not None:
                                image = image.convert(to_mode)
                            else:
                                # Image.open is lazy: read the pixel data now, while
                                # the file is still open, so the returned image stays
                                # usable after the with block closes the file.
                                image.load()
                    return image
                except Exception as e:
                    print(e)
                    return None
         | 
| 69 | 
            +
                    
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
            def imwrite_bytes(filename, image_bytes: bytes, update_extension: bool = True):
                """Write image bytes to file.

                Args:
                    filename: str
                        filename which image_bytes is written into.
                    image_bytes: bytes
                        image content to be written.
                    update_extension: bool
                        whether update extension according to image_bytes or not.
                        the cost of update extension is smaller than update image format.

                Returns:
                    The filename actually written, which differs from the input
                    when the extension was replaced.
                """
                detected_format = imghdr.what('', image_bytes)
                file_extension = khandy.get_path_extension(filename)

                # imghdr.what fails to determine image format sometimes!
                # so when its return value is None, never update extension.
                needs_reencode = False
                if detected_format is None:
                    needs_reencode = True
                elif detected_format.lower() != file_extension.lower()[1:]:
                    if update_extension:
                        filename = khandy.replace_path_extension(filename, detected_format)
                    else:
                        needs_reencode = True

                if needs_reencode:
                    # Re-encode the content so it matches the extension of filename.
                    decoded = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
                    image_bytes = cv2.imencode(file_extension, decoded)[1]

                with open(filename, "wb") as f:
                    f.write(image_bytes)
                return filename
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def rescale_image(image: np.ndarray, rescale_factor='auto', dst_dtype=np.float32):
                """Rescale image by rescale_factor.

                Args:
                    image (ndarray): Image to be rescaled. It is not modified.
                    rescale_factor (str, int or float, *optional*, defaults to `'auto'`):
                        rescale the image by the specified scale factor. When is `'auto'`,
                        the factor is `1 / max` of the image's unsigned integer dtype,
                        which rescales the image to [0, 1].
                    dst_dtype (np.dtype, *optional*, defaults to `np.float32`):
                        The dtype of the output image.

                Returns:
                    ndarray: The rescaled image.

                Raises:
                    TypeError: if `rescale_factor` is `'auto'` and the image dtype is
                        not an unsigned integer, or if `rescale_factor` is neither
                        `'auto'` nor a number.
                """
                # Check for str first so that array-like factors never reach `==`.
                if isinstance(rescale_factor, str) and rescale_factor == 'auto':
                    if np.issubdtype(image.dtype, np.unsignedinteger):
                        rescale_factor = 1. / np.iinfo(image.dtype).max
                    else:
                        raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
                elif not isinstance(rescale_factor, (int, float, np.integer, np.floating)):
                    # The original used issubclass() here, which raised for every
                    # numeric instance and made explicit factors unusable.
                    raise TypeError('rescale_factor must be "auto", int or float')
                # astype(copy=True) yields a fresh array, so the in-place multiply
                # below never touches the caller's data and keeps dst_dtype.
                image = image.astype(dst_dtype, copy=True)
                image *= rescale_factor
                return image
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            def normalize_image_value(image: np.ndarray, mean, std, rescale_factor=None):
                """Normalize an image with mean and std, rescale optionally.

                The rescale is applied to ``mean`` and ``std`` rather than to the
                image, then the result is ``(image - mean) / std``.

                Args:
                    image (ndarray): Image to be normalized. It is not modified.
                    mean (int, float, Sequence[int], Sequence[float], ndarray): The mean to be used for normalize.
                    std (int, float, Sequence[int], Sequence[float], ndarray): The std to be used for normalize.
                    rescale_factor (None, 'auto', int or float, *optional*, defaults to `None`):
                        When is `'auto'`, mean and std are multiplied by the maximum
                        value of the image's unsigned integer dtype; when int or
                        float, mean and std are multiplied by it; any other value
                        (including `None`) leaves them unchanged.

                Returns:
                    ndarray: The normalized image which dtype is np.float32.

                Raises:
                    TypeError: if `rescale_factor` is `'auto'` and the image dtype
                        is not an unsigned integer.
                """
                out_dtype = np.float32
                mean_arr = np.array(mean, dtype=out_dtype).flatten()
                std_arr = np.array(std, dtype=out_dtype).flatten()

                # NOTE(review): 'auto' multiplies mean/std by the dtype max, while a
                # numeric factor multiplies them by the factor itself - confirm
                # which convention callers expect for numeric factors.
                if rescale_factor == 'auto':
                    if not np.issubdtype(image.dtype, np.unsignedinteger):
                        raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
                    dtype_max = np.iinfo(image.dtype).max
                    mean_arr *= dtype_max
                    std_arr *= dtype_max
                elif isinstance(rescale_factor, (int, float)):
                    mean_arr *= rescale_factor
                    std_arr *= rescale_factor

                normalized = image.astype(out_dtype, copy=True)
                normalized -= mean_arr
                normalized /= std_arr
                return normalized
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            def normalize_image_dtype(image, keep_num_channels=False):
                """Normalize image dtype to uint8 (usually for visualization).

                The image is cast to float32, passed through
                ``khandy.minmax_normalize`` over the whole array, multiplied by
                255 and cast to uint8.

                Args:
                    image : ndarray
                        Input image, either 2-D or 3-D with 1 or 3 channels.
                    keep_num_channels : bool, optional
                        If this is set to True, the result is an array which has
                        the same shape as input image, otherwise the result is
                        an array whose channels number is 3.

                Returns:
                    out: ndarray
                        Image whose dtype is np.uint8.
                """
                is_2d = image.ndim == 2
                is_supported_3d = image.ndim == 3 and image.shape[-1] in [1, 3]
                assert is_supported_3d or is_2d

                as_float = image.astype(np.float32)
                scaled = khandy.minmax_normalize(as_float, axis=None, copy=False)
                result = np.array(scaled * 255, dtype=np.uint8)

                if keep_num_channels:
                    return result
                if result.ndim == 2:
                    result = np.expand_dims(result, -1)
                if result.shape[-1] == 1:
                    # Repeat the single channel three times along the last axis.
                    result = np.tile(result, (1, 1, 3))
                return result
         | 
| 190 | 
            +
                
         | 
| 191 | 
            +
                
         | 
| 192 | 
            +
            def normalize_image_channel(image, swap_rb=False):
                """Normalize image channel number and order to RGB or BGR.

                2-D and single-channel inputs are expanded with
                ``cv2.COLOR_GRAY2BGR``; 4-channel inputs lose their fourth
                channel; 3-channel inputs are returned unchanged unless
                ``swap_rb`` is set.

                Args:
                    image : ndarray
                        Input image.
                    swap_rb : bool, optional
                        whether swap red and blue channel or not

                Returns:
                    out: ndarray
                        Image whose shape is (..., 3).

                Raises:
                    ValueError: if the image is not 2-D or 3-D, or if a 3-D
                        image has a channel count other than 1, 3 or 4.
                """
                if image.ndim not in (2, 3):
                    raise ValueError(f'Unsupported image ndarray ndim, only support 2 and 3, got {image.ndim}!')
                if image.ndim == 2:
                    return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

                num_channels = image.shape[-1]
                if num_channels == 1:
                    return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
                if num_channels == 3:
                    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if swap_rb else image
                if num_channels == 4:
                    conversion = cv2.COLOR_BGRA2RGB if swap_rb else cv2.COLOR_BGRA2BGR
                    return cv2.cvtColor(image, conversion)
                raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
         | 
| 224 | 
            +
             | 
| 225 | 
            +
             | 
| 226 | 
            +
            def normalize_image_shape(image, swap_rb=False):
                """Alias of ``normalize_image_channel`` that emits a warning before delegating."""
                message = 'khandy.normalize_image_shape will be deprecated, use khandy.normalize_image_channel instead!'
                warnings.warn(message)
                return normalize_image_channel(image, swap_rb)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            def stack_image_list(image_list, dtype=np.float32):
                """Join a sequence of image along a new axis before first axis.

                Images may differ in height and width. Each one is placed in the
                top-left corner of its slot and the rest of the slot is filled
                with zeros. Single-channel and 2-D images are broadcast across
                the channel axis when other images have more channels.

                Args:
                    image_list: tuple or list of 2-D or 3-D ndarrays. Apart from
                        single-channel images, all images must share one channel
                        count.
                    dtype: dtype of the returned blob.

                Returns:
                    ndarray of shape (N, max_height, max_width, C), or
                    (N, max_height, max_width) when every image is 2-D.

                References:
                    `im_list_to_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
                """
                assert isinstance(image_list, (tuple, list))

                max_dimension = np.array([image.ndim for image in image_list]).max()
                assert max_dimension in [2, 3]
                max_shape = np.array([image.shape[:2] for image in image_list]).max(axis=0)

                num_channels = [1 if image.ndim == 2 else image.shape[-1]
                                for image in image_list]
                # At most one channel count other than 1 is allowed.
                assert len(set(num_channels) - {1}) in [0, 1]
                max_num_channels = np.max(num_channels)

                # Zero-initialise so the padding around smaller images is
                # deterministic; np.empty would leave uninitialized memory there.
                blob = np.zeros((len(image_list), max_shape[0], max_shape[1], max_num_channels), dtype=dtype)
                for k, image in enumerate(image_list):
                    blob[k, :image.shape[0], :image.shape[1], :] = np.atleast_3d(image).astype(dtype, copy=False)
                if max_dimension == 2:
                    blob = np.squeeze(blob, axis=-1)
                return blob
         | 
| 258 | 
            +
                
         | 
| 259 | 
            +
             | 
| 260 | 
            +
            def is_numpy_image(image):
                """Return True when `image` is a numpy array with 2 or 3 dimensions."""
                if not isinstance(image, np.ndarray):
                    return False
                return image.ndim in (2, 3)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
             | 
| 264 | 
            +
            def is_gray_image(image, tol=3):
                """Check whether an image is (approximately) grayscale.

                2-D images and 3-D images with a single channel always count as gray.
                A 3-channel image, or a 4-channel image after dropping the alpha
                channel, is converted to gray and back; it counts as gray when the
                mean absolute difference to that round trip is at most `tol`.
                Any other channel count yields False.

                Args:
                    image (ndarray): Image with 2 or 3 dimensions.
                    tol (number): Largest accepted mean absolute difference.

                Returns:
                    bool: Whether the image is considered grayscale.
                """
                assert is_numpy_image(image)
                if image.ndim == 2:
                    return True
                if image.ndim != 3:
                    return False

                channel_count = image.shape[-1]
                if channel_count == 1:
                    return True
                if channel_count == 3:
                    color = image
                elif channel_count == 4:
                    # The alpha channel plays no role in the comparison.
                    color = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
                else:
                    return False

                gray = cv2.cvtColor(color, cv2.COLOR_BGR2GRAY)
                gray_as_color = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
                mae = np.mean(cv2.absdiff(color, gray_as_color))
                return mae <= tol
         | 
| 287 | 
            +
                    
         | 
| 288 | 
            +
             | 
| 289 | 
            +
            def is_solid_color_image(image, tol=4):
                """Check whether an image is (approximately) filled with one color.

                The mean absolute deviation of the pixels from the per-channel mean
                is compared with `tol`. For 4-channel images the last channel is
                left out of the comparison. Images with 2 channels or more than 4
                channels are never reported as solid.

                Args:
                    image (ndarray): Image with 2 or 3 dimensions.
                    tol (number): Largest accepted mean absolute deviation.

                Returns:
                    bool: Whether the image is considered a solid color image.
                """
                assert is_numpy_image(image)
                # cv2.mean returns four values; only the first three are kept.
                channel_means = np.array(cv2.mean(image)[:-1], dtype=np.float32)

                if image.ndim == 2:
                    deviation = image - channel_means[0]
                elif image.ndim == 3:
                    channel_count = image.shape[-1]
                    if channel_count == 1:
                        deviation = image - channel_means[0]
                    elif channel_count == 3:
                        deviation = image - channel_means
                    elif channel_count == 4:
                        deviation = image[:, :, :-1] - channel_means
                    else:
                        return False
                else:
                    return False

                mae = np.mean(np.abs(deviation))
                return mae <= tol
         | 
| 311 | 
            +
             | 
| 312 | 
            +
             | 
| 313 | 
            +
            def create_solid_color_image(image_width, image_height, color, dtype=None):
                """Create an image in which every pixel has the same color.

                Args:
                    image_width (int): Width of the image in pixels.
                    image_height (int): Height of the image in pixels.
                    color (number | tuple | list): A real number or a one-element
                        sequence gives a 2-D (height, width) image. A longer sequence
                        gives a (height, width, len(color)) image.
                    dtype: Data type of the result. When None, numpy infers it from
                        `color`, except for sequences whose length is not 1, 3 or 4,
                        where the image falls back to numpy's default dtype.

                Returns:
                    ndarray: The solid color image.

                Raises:
                    TypeError: If `color` is neither a real number nor a tuple/list.
                """
                if isinstance(color, numbers.Real):
                    return np.full((image_height, image_width), color, dtype=dtype)
                if not isinstance(color, (tuple, list)):
                    raise TypeError(f'Invalid type {type(color)} for `color`.')

                if len(color) == 1:
                    return np.full((image_height, image_width), color[0], dtype=dtype)
                if len(color) in (3, 4):
                    # np.full broadcasts the color over all pixels. Filling the whole
                    # image in one step applies a single casting rule to every pixel
                    # and works for any numpy dtype.
                    return np.full((image_height, image_width, len(color)), color, dtype=dtype)

                # Other channel counts keep their historical behavior.
                color = np.asarray(color, dtype=dtype)
                image = np.empty((image_height, image_width, len(color)), dtype=dtype)
                image[:] = color
                return image
         | 
    	
        khandy/image/resize.py
    ADDED
    
    | @@ -0,0 +1,177 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import warnings
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import cv2
         | 
| 4 | 
            +
            import khandy
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            # Interpolation names accepted by the resize helpers in this module,
            # mapped to the corresponding OpenCV interpolation flags.
            interp_codes = dict(
                nearest=cv2.INTER_NEAREST,
                bilinear=cv2.INTER_LINEAR,
                bicubic=cv2.INTER_CUBIC,
                area=cv2.INTER_AREA,
                lanczos=cv2.INTER_LANCZOS4,
            )
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def scale_image(image, x_scale, y_scale, interpolation='bilinear'):
                """Rescale an image by separate horizontal and vertical factors.

                Args:
                    image (ndarray): The input image.
                    x_scale (number): Factor applied to the width.
                    y_scale (number): Factor applied to the height.
                    interpolation (str): A key of `interp_codes`.

                Returns:
                    ndarray: The rescaled image; its size is the scaled source size
                        rounded to whole pixels.

                Reference:
                    mmcv.imrescale
                """
                assert khandy.is_numpy_image(image)
                src_h, src_w = image.shape[:2]
                dst_size = (int(round(x_scale * src_w)), int(round(y_scale * src_h)))
                return cv2.resize(image, dst_size,
                                  interpolation=interp_codes[interpolation])
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def resize_image(image, dst_width, dst_height, return_scale=False, interpolation='bilinear'):
                """Resize an image to an explicit width and height.

                Args:
                    image (ndarray): The input image.
                    dst_width (int): Width of the output.
                    dst_height (int): Height of the output.
                    return_scale (bool): When True, also return the horizontal and
                        vertical scale factors.
                    interpolation (str): One of "nearest", "bilinear", "bicubic",
                        "area", "lanczos".

                Returns:
                    ndarray or tuple: `resized_image`, or
                        (`resized_image`, `x_scale`, `y_scale`) when `return_scale`
                        is True.

                Reference:
                    mmcv.imresize
                """
                assert khandy.is_numpy_image(image)
                interp_flag = interp_codes[interpolation]
                resized = cv2.resize(image, (dst_width, dst_height), interpolation=interp_flag)
                if return_scale:
                    src_h, src_w = image.shape[:2]
                    return resized, dst_width / src_w, dst_height / src_h
                return resized
         | 
| 60 | 
            +
                
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
            def resize_image_short(image, dst_size, return_scale=False, interpolation='bilinear'):
                """Resize so the shorter side becomes `dst_size`, keeping the aspect ratio.

                Args:
                    image (ndarray): The input image.
                    dst_size (number): Target length of the shorter side.
                    return_scale (bool): When True, also return the scale factor.
                    interpolation (str): A key of `interp_codes`.

                Returns:
                    ndarray or tuple: `resized_image`, or (`resized_image`, `scale`)
                        when `return_scale` is True.

                References:
                    `resize_min` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
                """
                assert khandy.is_numpy_image(image)
                src_h, src_w = image.shape[:2]
                # The larger of the two ratios is the one that maps the shorter side
                # onto dst_size.
                scale = max(dst_size / side for side in (src_w, src_h))
                dst_shape = (int(round(scale * src_w)), int(round(scale * src_h)))

                resized = cv2.resize(image, dst_shape,
                                     interpolation=interp_codes[interpolation])
                if return_scale:
                    return resized, scale
                return resized
         | 
| 81 | 
            +
                
         | 
| 82 | 
            +
                
         | 
| 83 | 
            +
            def resize_image_long(image, dst_size, return_scale=False, interpolation='bilinear'):
                """Resize so the longer side becomes `dst_size`, keeping the aspect ratio.

                Args:
                    image (ndarray): The input image.
                    dst_size (number): Target length of the longer side.
                    return_scale (bool): When True, also return the scale factor.
                    interpolation (str): A key of `interp_codes`.

                Returns:
                    ndarray or tuple: `resized_image`, or (`resized_image`, `scale`)
                        when `return_scale` is True.

                References:
                    `resize_max` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
                """
                assert khandy.is_numpy_image(image)
                src_h, src_w = image.shape[:2]
                # The smaller of the two ratios is the one that maps the longer side
                # onto dst_size.
                scale = min(dst_size / side for side in (src_w, src_h))
                dst_shape = (int(round(scale * src_w)), int(round(scale * src_h)))

                resized = cv2.resize(image, dst_shape,
                                     interpolation=interp_codes[interpolation])
                if return_scale:
                    return resized, scale
                return resized
         | 
| 102 | 
            +
                    
         | 
| 103 | 
            +
                    
         | 
| 104 | 
            +
            def resize_image_to_range(image, min_length, max_length, return_scale=False, interpolation='bilinear'):
                """Resize an image so that its side lengths fit a given range.

                The aspect ratio is kept. The scale is chosen as follows:
                1. Scale the shorter side to `min_length` if the longer side then
                    (after rounding) does not exceed `max_length`.
                2. Otherwise scale the longer side to `max_length`.

                Args:
                    image (ndarray): The input image.
                    min_length (number): Desired length of the shorter side.
                    max_length (number): Upper bound for the longer side; must be
                        greater than `min_length`.
                    return_scale (bool): When True, also return the scale factor.
                    interpolation (str): A key of `interp_codes`.

                Returns:
                    ndarray or tuple: `resized_image`, or (`resized_image`, `scale`)
                        when `return_scale` is True.

                References:
                    `resize_to_range` in `models-master/research/object_detection/core/preprocessor.py`
                    `prep_im_for_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
                    mmcv.imrescale
                """
                assert khandy.is_numpy_image(image)
                assert min_length < max_length
                src_h, src_w = image.shape[:2]

                short_side = min(src_w, src_h)
                long_side = max(src_w, src_h)
                scale = min_length / short_side
                if round(scale * long_side) > max_length:
                    # Case 2: the longer side would overshoot, so it drives the scale.
                    scale = max_length / long_side
                dst_shape = (int(round(scale * src_w)), int(round(scale * src_h)))

                resized = cv2.resize(image, dst_shape,
                                     interpolation=interp_codes[interpolation])
                if return_scale:
                    return resized, scale
                return resized
         | 
| 144 | 
            +
                    
         | 
| 145 | 
            +
                    
         | 
| 146 | 
            +
            def letterbox_image(image, dst_width, dst_height, border_value=0,
                                return_scale=False, interpolation='bilinear'):
                """Fit an image into a target size, keeping the aspect ratio by padding.

                The image is scaled uniformly so it fits inside
                (`dst_width`, `dst_height`), then constant borders are added around
                it. Any odd leftover pixel goes to the bottom / right border.

                Args:
                    image (ndarray): The input image.
                    dst_width (int): Width of the output.
                    dst_height (int): Height of the output.
                    border_value: Value used for the padded border.
                    return_scale (bool): When True, also return the scale and the
                        left / top padding.
                    interpolation (str): A key of `interp_codes`.

                Returns:
                    ndarray or tuple: `padded_image`, or
                        (`padded_image`, `scale`, `pad_left`, `pad_top`) when
                        `return_scale` is True.

                References:
                    `letterbox_image` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
                """
                assert khandy.is_numpy_image(image)
                src_h, src_w = image.shape[:2]
                scale = min(dst_width / src_w, dst_height / src_h)
                inner_w = int(round(scale * src_w))
                inner_h = int(round(scale * src_h))

                resized = cv2.resize(image, (inner_w, inner_h),
                                     interpolation=interp_codes[interpolation])

                spare_h = dst_height - inner_h
                spare_w = dst_width - inner_w
                pad_top = spare_h // 2
                pad_bottom = spare_h - pad_top
                pad_left = spare_w // 2
                pad_right = spare_w - pad_left
                padded = cv2.copyMakeBorder(resized, pad_top, pad_bottom, pad_left, pad_right,
                                            cv2.BORDER_CONSTANT, value=border_value)
                if return_scale:
                    return padded, scale, pad_left, pad_top
                return padded
         | 
| 171 | 
            +
                    
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            def letterbox_resize_image(image, dst_width, dst_height, border_value=0,
                                       return_scale=False, interpolation='bilinear'):
                """Legacy alias of `letterbox_image`; emits a warning on every call."""
                warnings.warn('letterbox_resize_image will be deprecated, use letterbox_image instead!')
                return letterbox_image(image, dst_width, dst_height,
                                       border_value=border_value,
                                       return_scale=return_scale,
                                       interpolation=interpolation)
         | 
    	
        khandy/image/rotate.py
    ADDED
    
    | @@ -0,0 +1,72 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import khandy
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def get_2d_rotation_matrix(angle, cx=0, cy=0, scale=1, 
                                       degrees=True, dtype=np.float32):
                """Build a 3x3 homogeneous matrix for a scaled rotation about (cx, cy).

                Args:
                    angle (number): Rotation angle, in degrees when `degrees` is True,
                        otherwise in radians.
                    cx (number): x coordinate of the rotation center.
                    cy (number): y coordinate of the rotation center.
                    scale (number): Isotropic scale factor.
                    degrees (bool): Unit of `angle`.
                    dtype: Data type of the returned matrix.

                Returns:
                    ndarray: Matrix [[c, -s, tx], [s, c, ty], [0, 0, 1]] with
                        c = scale * cos(angle) and s = scale * sin(angle). The
                        translation is chosen so that (cx, cy) maps onto itself.

                References:
                    `cv2.getRotationMatrix2D` in OpenCV
                """
                theta = np.deg2rad(angle) if degrees else angle
                cos_term = scale * np.cos(theta)
                sin_term = scale * np.sin(theta)

                # Translation that keeps the rotation center fixed.
                shift_x = cx - cx * cos_term + cy * sin_term
                shift_y = cy - cx * sin_term - cy * cos_term

                rows = [
                    [cos_term, -sin_term, shift_x],
                    [sin_term, cos_term, shift_y],
                    [0, 0, 1],
                ]
                return np.array(rows, dtype=dtype)
         | 
| 22 | 
            +
                
         | 
| 23 | 
            +
                
         | 
| 24 | 
            +
            def rotate_image(image, angle, scale=1.0, center=None, 
                             degrees=True, border_value=0, auto_bound=False):
                """Rotate an image.

                Args:
                    image : ndarray
                        Image to be rotated.
                    angle : float
                        Rotation angle, positive values mean clockwise rotation.
                        Interpreted as degrees unless `degrees` is False.
                    scale : float
                        Isotropic scale factor.
                    center : tuple
                        Center of the rotation in the source image, by default
                        it is the center of the image. Ignored (reset to the image
                        center) when `auto_bound` is True.
                    degrees : bool
                        Whether `angle` is given in degrees (otherwise radians).
                    border_value : int
                        Border value.
                    auto_bound : bool
                        Whether to adjust the image size to cover the whole rotated image.

                Returns:
                    ndarray: The rotated image.

                References:
                    mmcv.imrotate
                """
                assert khandy.is_numpy_image(image)
                src_h, src_w = image.shape[:2]
                if auto_bound or center is None:
                    center = ((src_w - 1) * 0.5, (src_h - 1) * 0.5)
                assert isinstance(center, tuple)

                matrix = get_2d_rotation_matrix(angle, center[0], center[1], scale, degrees)
                out_w, out_h = src_w, src_h
                if auto_bound:
                    abs_cos = np.abs(matrix[0, 0])
                    abs_sin = np.abs(matrix[0, 1])
                    bound_w = src_w * abs_cos + src_h * abs_sin
                    bound_h = src_w * abs_sin + src_h * abs_cos

                    # Shift the result so it is centered in the enlarged canvas.
                    matrix[0, 2] += (bound_w - src_w) * 0.5
                    matrix[1, 2] += (bound_h - src_h) * 0.5

                    out_w = int(np.round(bound_w))
                    out_h = int(np.round(bound_h))
                return cv2.warpAffine(image, matrix[:2, :], (out_w, out_h),
                                      borderValue=border_value)
         | 
    	
        khandy/image/translate.py
    ADDED
    
    | @@ -0,0 +1,57 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numbers
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import khandy
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def translate_image(image, x_shift, y_shift, border_value=0):
                """Translate an image.

                The output has the same shape and dtype as the input. Pixels shifted
                out of the frame are dropped, and the uncovered area is filled with
                `border_value`.

                Args:
                    image (ndarray): Image to be translated with format (h, w) or (h, w, c).
                    x_shift (int): The offset used for translate in horizontal
                        direction. right is the positive direction.
                    y_shift (int): The offset used for translate in vertical
                        direction. down is the positive direction.
                    border_value (int | tuple[int]): Value used in case of a 
                        constant border. A tuple/list must have one element per
                        channel of `image`.

                Returns:
                    ndarray: The translated image.

                See Also:
                    crop_or_pad
                """
                assert khandy.is_numpy_image(image)
                assert isinstance(x_shift, numbers.Integral)
                assert isinstance(y_shift, numbers.Integral)
                image_height, image_width = image.shape[:2]
                channels = 1 if image.ndim == 2 else image.shape[2]

                if isinstance(border_value, (tuple, list)):
                    assert len(border_value) == channels, \
                        'Expected the num of elements in tuple equals the channels ' \
                        'of input image. Found {} vs {}'.format(
                            len(border_value), channels)
                else:
                    border_value = (border_value,) * channels
                # create_solid_color_image takes (width, height); passing them the
                # other way round yields a transposed canvas for non-square images.
                dst_image = khandy.create_solid_color_image(
                    image_width, image_height, border_value, dtype=image.dtype)
                if dst_image.shape != image.shape:
                    # A one-element color gives a 2-D canvas; restore the trailing
                    # channel axis of an (h, w, 1) input.
                    dst_image = dst_image.reshape(image.shape)

                if (abs(x_shift) >= image_width) or (abs(y_shift) >= image_height):
                    # The whole image is shifted out of view.
                    return dst_image

                src_x_begin = max(-x_shift, 0)
                src_x_end   = min(image_width - x_shift, image_width)
                dst_x_begin = max(x_shift, 0)
                dst_x_end   = min(image_width + x_shift, image_width)

                src_y_begin = max(-y_shift, 0)
                src_y_end   = min(image_height - y_shift, image_height)
                dst_y_begin = max(y_shift, 0)
                dst_y_end   = min(image_height + y_shift, image_height)

                dst_image[dst_y_begin:dst_y_end, dst_x_begin:dst_x_end] = \
                    image[src_y_begin:src_y_end, src_x_begin:src_x_end]
                return dst_image
         | 
| 56 | 
            +
                
         | 
| 57 | 
            +
                
         | 
    	
        khandy/label/__init__.py
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .detect import *
         | 
| 2 | 
            +
             | 
    	
        khandy/label/detect.py
    ADDED
    
    | @@ -0,0 +1,594 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import copy
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            import dataclasses
         | 
| 5 | 
            +
            from dataclasses import dataclass, field
         | 
| 6 | 
            +
            from collections import OrderedDict
         | 
| 7 | 
            +
            from typing import Optional, List
         | 
| 8 | 
            +
            import xml.etree.ElementTree as ET
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import khandy
         | 
| 11 | 
            +
            import lxml
         | 
| 12 | 
            +
            import lxml.builder
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
         | 
| 17 | 
            +
                       'save_detect', 'convert_detect', 'replace_detect_label',
         | 
| 18 | 
            +
                       'load_coco_class_names']
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            @dataclass
            class DetectIrObject:
                """A single labelled box in the intermediate representation (IR).

                The box is described by two corner coordinates,
                ``(x_min, y_min)`` and ``(x_max, y_max)``.
                """
                label: str    # class name attached to the box
                x_min: float  # first corner, x
                y_min: float  # first corner, y
                x_max: float  # second corner, x
                y_max: float  # second corner, y
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            @dataclass
            class DetectIrRecord:
                """All annotations of one image in the intermediate representation (IR).

                Format handlers convert their own record type to and from this class.
                """
                filename: str  # image file the annotations belong to
                width: int     # image width
                height: int    # image height
                # Every record gets its own list, so instances never share objects.
                objects: List[DetectIrObject] = field(default_factory=list)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            @dataclass
            class PascalVocSource:
                """Contents of the ``<source>`` element of a Pascal VOC annotation.

                Every field defaults to the empty string.
                """
                database: str = ''
                annotation: str = ''
                image: str = ''
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            @dataclass
            class PascalVocSize:
                """Contents of the ``<size>`` element of a Pascal VOC annotation.

                Note the positional order: height first, then width, then depth.
                """
                height: int
                width: int
                depth: int
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            @dataclass
            class PascalVocBndbox:
                """Contents of the ``<bndbox>`` element of a Pascal VOC object."""
                xmin: float
                ymin: float
                xmax: float
                ymax: float
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            @dataclass
            class PascalVocObject:
                """Contents of one ``<object>`` element of a Pascal VOC annotation.

                Only ``name`` is required; the remaining fields carry defaults.
                """
                name: str                                 # class name of the object
                pose: str = 'Unspecified'
                truncated: int = 0
                difficult: int = 0
                bndbox: Optional[PascalVocBndbox] = None  # None until a box is attached
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            @dataclass
            class PascalVocRecord:
                """In-memory form of one Pascal VOC annotation file.

                Every field has a default, so ``PascalVocRecord()`` gives an empty
                record that can be filled in afterwards.
                """
                folder: str = ''
                filename: str = ''
                path: str = ''
                # Built per instance. A single PascalVocSource() as the default would be
                # shared by every record, and Python 3.11+ rejects such an unhashable
                # default with ValueError when the class is created.
                source: PascalVocSource = field(default_factory=PascalVocSource)
                size: Optional[PascalVocSize] = None  # None until the image size is known
                segmented: int = 0
                objects: List[PascalVocObject] = field(default_factory=list)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            class PascalVocHandler:
                """Load, save and convert Pascal VOC XML annotations.

                All methods are static; the class only serves as a namespace.
                """

                @staticmethod
                def _tag_text(parent, tag, default=''):
                    """Return the text of child ``tag`` under ``parent``, or ``default``.

                    ``default`` is returned when ``parent`` is None, when the child
                    tag is absent, or when the tag carries no text.
                    """
                    if parent is None:
                        return default
                    text = parent.findtext(tag)
                    return text if text else default

                @staticmethod
                def load(filename, **kwargs) -> PascalVocRecord:
                    """Parse a Pascal VOC XML file into a ``PascalVocRecord``.

                    Args:
                        filename: path of the XML file.
                        **kwargs: accepted but not used by this handler.

                    Returns:
                        PascalVocRecord: the parsed record.

                    Notes:
                        ``folder``, ``filename``, ``path``, ``segmented``,
                        ``source/database``, ``pose``, ``truncated`` and
                        ``difficult`` fall back to the dataclass defaults when
                        the tag is missing or empty. ``size``, ``name`` and
                        ``bndbox`` remain mandatory.
                        Each box coordinate is stored as the file value minus 1
                        (presumably converting 1-based VOC indices to 0-based;
                        NOTE(review): ``save`` does not add the 1 back, confirm
                        the asymmetry is intended).
                    """
                    text_of = PascalVocHandler._tag_text
                    xml_tree = ET.parse(filename)

                    pascal_voc_record = PascalVocRecord(
                        folder=text_of(xml_tree, 'folder'),
                        filename=text_of(xml_tree, 'filename'),
                        path=text_of(xml_tree, 'path'),
                        source=PascalVocSource(
                            database=text_of(xml_tree.find('source'), 'database'),
                        ),
                        # Stored as int to match the dataclass annotation.
                        segmented=int(text_of(xml_tree, 'segmented', '0')),
                    )

                    size_tag = xml_tree.find('size')
                    pascal_voc_record.size = PascalVocSize(
                        width=int(size_tag.find('width').text),
                        height=int(size_tag.find('height').text),
                        depth=int(size_tag.find('depth').text)
                    )

                    for object_tag in xml_tree.findall('object'):
                        bndbox_tag = object_tag.find('bndbox')
                        bndbox = PascalVocBndbox(
                            xmin=float(bndbox_tag.find('xmin').text) - 1,
                            ymin=float(bndbox_tag.find('ymin').text) - 1,
                            xmax=float(bndbox_tag.find('xmax').text) - 1,
                            ymax=float(bndbox_tag.find('ymax').text) - 1
                        )
                        pascal_voc_object = PascalVocObject(
                            name=object_tag.find('name').text,
                            pose=text_of(object_tag, 'pose', 'Unspecified'),
                            truncated=int(text_of(object_tag, 'truncated', '0')),
                            difficult=int(text_of(object_tag, 'difficult', '0')),
                            bndbox=bndbox
                        )
                        pascal_voc_record.objects.append(pascal_voc_object)
                    return pascal_voc_record

                @staticmethod
                def save(filename, pascal_voc_record: PascalVocRecord):
                    """Write ``pascal_voc_record`` as a pretty-printed UTF-8 XML file.

                    ``'.xml'`` is appended to ``filename`` when it does not already
                    end with it. Box coordinates are written as floats, exactly
                    as stored in the record. Only ``source.database`` is written
                    from the source section.
                    """
                    maker = lxml.builder.ElementMaker()
                    xml = maker.annotation(
                        maker.folder(pascal_voc_record.folder),
                        maker.filename(pascal_voc_record.filename),
                        maker.path(pascal_voc_record.path),
                        maker.source(
                            maker.database(pascal_voc_record.source.database),
                        ),
                        maker.size(
                            maker.width(str(pascal_voc_record.size.width)),
                            maker.height(str(pascal_voc_record.size.height)),
                            maker.depth(str(pascal_voc_record.size.depth)),
                        ),
                        maker.segmented(str(pascal_voc_record.segmented)),
                    )

                    for pascal_voc_object in pascal_voc_record.objects:
                        bndbox = pascal_voc_object.bndbox
                        object_tag = maker.object(
                            maker.name(pascal_voc_object.name),
                            maker.pose(pascal_voc_object.pose),
                            maker.truncated(str(pascal_voc_object.truncated)),
                            maker.difficult(str(pascal_voc_object.difficult)),
                            maker.bndbox(
                                maker.xmin(str(float(bndbox.xmin))),
                                maker.ymin(str(float(bndbox.ymin))),
                                maker.xmax(str(float(bndbox.xmax))),
                                maker.ymax(str(float(bndbox.ymax))),
                            ),
                        )
                        xml.append(object_tag)

                    if not filename.endswith('.xml'):
                        filename = filename + '.xml'
                    with open(filename, 'wb') as f:
                        f.write(lxml.etree.tostring(
                            xml, pretty_print=True, encoding='utf-8'))

                @staticmethod
                def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
                    """Convert a ``PascalVocRecord`` to the intermediate representation.

                    Only the filename, the image width/height and each object's
                    name and box are carried over. ``pascal_voc_record.size``
                    and every object's ``bndbox`` must be set.
                    """
                    ir_record = DetectIrRecord(
                        filename=pascal_voc_record.filename,
                        width=pascal_voc_record.size.width,
                        height=pascal_voc_record.size.height
                    )
                    for pascal_voc_object in pascal_voc_record.objects:
                        bndbox = pascal_voc_object.bndbox
                        ir_record.objects.append(DetectIrObject(
                            label=pascal_voc_object.name,
                            x_min=bndbox.xmin,
                            y_min=bndbox.ymin,
                            x_max=bndbox.xmax,
                            y_max=bndbox.ymax
                        ))
                    return ir_record

                @staticmethod
                def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
                    """Build a ``PascalVocRecord`` from the intermediate representation.

                    The IR carries no channel count, so ``depth`` is fixed to 3.
                    All other Pascal VOC fields keep their dataclass defaults.
                    """
                    pascal_voc_record = PascalVocRecord(
                        filename=ir_record.filename,
                        size=PascalVocSize(
                            width=ir_record.width,
                            height=ir_record.height,
                            depth=3
                        )
                    )
                    for ir_object in ir_record.objects:
                        pascal_voc_record.objects.append(PascalVocObject(
                            name=ir_object.label,
                            bndbox=PascalVocBndbox(
                                xmin=ir_object.x_min,
                                ymin=ir_object.y_min,
                                xmax=ir_object.x_max,
                                ymax=ir_object.y_max,
                            )
                        ))
                    return pascal_voc_record
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            class _NumpyEncoder(json.JSONEncoder):
         | 
| 210 | 
            +
                """ Special json encoder for numpy types """
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                def default(self, obj):
         | 
| 213 | 
            +
                    if isinstance(obj, (np.bool_,)):
         | 
| 214 | 
            +
                        return bool(obj)
         | 
| 215 | 
            +
                    elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
         | 
| 216 | 
            +
                                          np.int16, np.int32, np.int64, np.uint8,
         | 
| 217 | 
            +
                                          np.uint16, np.uint32, np.uint64)):
         | 
| 218 | 
            +
                        return int(obj)
         | 
| 219 | 
            +
                    elif isinstance(obj, (np.float_, np.float16, np.float32,
         | 
| 220 | 
            +
                                          np.float64)):
         | 
| 221 | 
            +
                        return float(obj)
         | 
| 222 | 
            +
                    elif isinstance(obj, (np.ndarray,)):
         | 
| 223 | 
            +
                        return obj.tolist()
         | 
| 224 | 
            +
                    return json.JSONEncoder.default(self, obj)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            @dataclass
            class LabelmeShape:
                """One entry of the ``shapes`` list of a labelme annotation.

                ``points`` may be given as any array-like; it is converted to a
                NumPy array once the instance is created.
                """
                label: str
                points: np.ndarray
                shape_type: str
                flags: dict = field(default_factory=dict)
                group_id: Optional[int] = None

                def __post_init__(self):
                    # Normalize whatever was passed in (e.g. nested lists) to an ndarray.
                    self.points = np.asarray(self.points)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
             | 
| 239 | 
            +
            @dataclass
            class LabelmeRecord:
                """In-memory form of one labelme JSON annotation.

                Field names match the keys passed in by ``LabelmeHandler.load``,
                which builds the record with ``LabelmeRecord(**json_content)``.
                """
                version: str = '4.5.6'
                flags: dict = field(default_factory=dict)
                shapes: List[LabelmeShape] = field(default_factory=list)
                imagePath: Optional[str] = None
                imageData: Optional[str] = None
                imageHeight: Optional[int] = None
                imageWidth: Optional[int] = None

                def __post_init__(self):
                    # Shapes loaded from JSON arrive as plain dicts, while programmatic
                    # callers may pass ready-made LabelmeShape objects. Accept both, and
                    # build a new list so the caller's list is left untouched.
                    self.shapes = [
                        shape if isinstance(shape, LabelmeShape) else LabelmeShape(**shape)
                        for shape in self.shapes
                    ]
         | 
| 252 | 
            +
             | 
| 253 | 
            +
             | 
| 254 | 
            +
            class LabelmeHandler:
                """Load, save and convert labelme JSON annotations.

                All methods are static; the class only serves as a namespace.
                """

                @staticmethod
                def load(filename, **kwargs) -> LabelmeRecord:
                    """Read a labelme JSON file into a ``LabelmeRecord``.

                    ``**kwargs`` is accepted but not used by this handler.
                    """
                    json_content = khandy.load_json(filename)
                    return LabelmeRecord(**json_content)

                @staticmethod
                def save(filename, labelme_record: LabelmeRecord):
                    """Write ``labelme_record`` to ``filename`` as JSON.

                    NumPy values (such as shape points) are serialized through
                    ``_NumpyEncoder``.
                    """
                    json_content = dataclasses.asdict(labelme_record)
                    khandy.save_json(filename, json_content, cls=_NumpyEncoder)

                @staticmethod
                def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
                    """Convert a ``LabelmeRecord`` to the intermediate representation.

                    Only shapes whose ``shape_type`` is ``'rectangle'`` are kept.
                    The two stored corner points may come in either order, so
                    the box is normalized so that ``x_min <= x_max`` and
                    ``y_min <= y_max``.
                    """
                    ir_record = DetectIrRecord(
                        filename=labelme_record.imagePath,
                        width=labelme_record.imageWidth,
                        height=labelme_record.imageHeight
                    )
                    for labelme_shape in labelme_record.shapes:
                        if labelme_shape.shape_type != 'rectangle':
                            continue
                        first, second = labelme_shape.points[0], labelme_shape.points[1]
                        ir_record.objects.append(DetectIrObject(
                            label=labelme_shape.label,
                            x_min=min(first[0], second[0]),
                            y_min=min(first[1], second[1]),
                            x_max=max(first[0], second[0]),
                            y_max=max(first[1], second[1]),
                        ))
                    return ir_record

                @staticmethod
                def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
                    """Build a ``LabelmeRecord`` from the intermediate representation.

                    Every IR object becomes a ``'rectangle'`` shape whose two
                    points are ``(x_min, y_min)`` and ``(x_max, y_max)``.
                    """
                    labelme_record = LabelmeRecord(
                        imagePath=ir_record.filename,
                        imageWidth=ir_record.width,
                        imageHeight=ir_record.height
                    )
                    for ir_object in ir_record.objects:
                        labelme_record.shapes.append(LabelmeShape(
                            label=ir_object.label,
                            shape_type='rectangle',
                            points=[[ir_object.x_min, ir_object.y_min],
                                    [ir_object.x_max, ir_object.y_max]]
                        ))
                    return labelme_record
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            @dataclass
            class YoloObject:
                """One line of a YOLO label file: a label and a centre/size box."""
                label: str
                x_center: float
                y_center: float
                width: float
                height: float
         | 
| 310 | 
            +
             | 
| 311 | 
            +
             | 
| 312 | 
            +
            @dataclass
            class YoloRecord:
                """All YOLO objects of one image plus the image's name and size.

                Every field is optional, so ``YoloRecord()`` gives an empty record.
                """
                filename: Optional[str] = None  # image file name
                width: Optional[int] = None     # image width
                height: Optional[int] = None    # image height
                # Every record gets its own list, so instances never share objects.
                objects: List[YoloObject] = field(default_factory=list)
         | 
| 318 | 
            +
             | 
| 319 | 
            +
             | 
| 320 | 
            +
            class YoloHandler:
         | 
| 321 | 
            +
                @staticmethod
         | 
| 322 | 
            +
                def load(filename, **kwargs) -> YoloRecord:
                    """Read a YOLO label file into a ``YoloRecord``.

                    Args:
                        filename: path of the label text file.
                        **kwargs: must contain ``image_filename``, ``width`` and
                            ``height``; they are copied into the returned record
                            because the label file itself does not carry them.

                    Returns:
                        YoloRecord: one ``YoloObject`` per non-blank line. Each
                        line is split on whitespace into
                        ``label x_center y_center width height``.
                    """
                    assert 'image_filename' in kwargs
                    assert 'width' in kwargs and 'height' in kwargs

                    yolo_record = YoloRecord(
                        filename=kwargs.get('image_filename'),
                        width=kwargs.get('width'),
                        height=kwargs.get('height'))
                    for record in khandy.load_list(filename):
                        record_parts = record.split()
                        if not record_parts:
                            # Blank line (e.g. a trailing newline). Indexing it would
                            # raise IndexError. NOTE(review): harmless if
                            # khandy.load_list already drops blank lines; confirm.
                            continue
                        yolo_record.objects.append(YoloObject(
                            label=record_parts[0],
                            x_center=float(record_parts[1]),
                            y_center=float(record_parts[2]),
                            width=float(record_parts[3]),
                            height=float(record_parts[4]),
                        ))
                    return yolo_record
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                @staticmethod
         | 
| 343 | 
            +
                def save(filename, yolo_record: YoloRecord):
         | 
| 344 | 
            +
                    records = []
         | 
| 345 | 
            +
                    for object in yolo_record.objects:
         | 
| 346 | 
            +
                        records.append(
         | 
| 347 | 
            +
                            f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
         | 
| 348 | 
            +
                    if not filename.endswith('.txt'):
         | 
| 349 | 
            +
                        filename = filename + '.txt'
         | 
| 350 | 
            +
                    khandy.save_list(filename, records)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                @staticmethod
         | 
| 353 | 
            +
                def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
         | 
| 354 | 
            +
                    ir_record = DetectIrRecord(
         | 
| 355 | 
            +
                        filename=yolo_record.filename,
         | 
| 356 | 
            +
                        width=yolo_record.width,
         | 
| 357 | 
            +
                        height=yolo_record.height
         | 
| 358 | 
            +
                    )
         | 
| 359 | 
            +
                    for yolo_object in yolo_record.objects:
         | 
| 360 | 
            +
                        x_min = (yolo_object.x_center - 0.5 *
         | 
| 361 | 
            +
                                 yolo_object.width) * yolo_record.width
         | 
| 362 | 
            +
                        y_min = (yolo_object.y_center - 0.5 *
         | 
| 363 | 
            +
                                 yolo_object.height) * yolo_record.height
         | 
| 364 | 
            +
                        x_max = (yolo_object.x_center + 0.5 *
         | 
| 365 | 
            +
                                 yolo_object.width) * yolo_record.width
         | 
| 366 | 
            +
                        y_max = (yolo_object.y_center + 0.5 *
         | 
| 367 | 
            +
                                 yolo_object.height) * yolo_record.height
         | 
| 368 | 
            +
                        ir_object = DetectIrObject(
         | 
| 369 | 
            +
                            label=yolo_object.label,
         | 
| 370 | 
            +
                            x_min=x_min,
         | 
| 371 | 
            +
                            y_min=y_min,
         | 
| 372 | 
            +
                            x_max=x_max,
         | 
| 373 | 
            +
                            y_max=y_max
         | 
| 374 | 
            +
                        )
         | 
| 375 | 
            +
                        ir_record.objects.append(ir_object)
         | 
| 376 | 
            +
                    return ir_record
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                @staticmethod
         | 
| 379 | 
            +
                def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
         | 
| 380 | 
            +
                    yolo_record = YoloRecord(
         | 
| 381 | 
            +
                        filename=ir_record.filename,
         | 
| 382 | 
            +
                        width=ir_record.width,
         | 
| 383 | 
            +
                        height=ir_record.height
         | 
| 384 | 
            +
                    )
         | 
| 385 | 
            +
                    for ir_object in ir_record.objects:
         | 
| 386 | 
            +
                        x_center = (ir_object.x_max + ir_object.x_min) / \
         | 
| 387 | 
            +
                            (2 * ir_record.width)
         | 
| 388 | 
            +
                        y_center = (ir_object.y_max + ir_object.y_min) / \
         | 
| 389 | 
            +
                            (2 * ir_record.height)
         | 
| 390 | 
            +
                        width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
         | 
| 391 | 
            +
                        height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
         | 
| 392 | 
            +
                        yolo_object = YoloObject(
         | 
| 393 | 
            +
                            label=ir_object.label,
         | 
| 394 | 
            +
                            x_center=x_center,
         | 
| 395 | 
            +
                            y_center=y_center,
         | 
| 396 | 
            +
                            width=width,
         | 
| 397 | 
            +
                            height=height,
         | 
| 398 | 
            +
                        )
         | 
| 399 | 
            +
                        yolo_record.objects.append(yolo_object)
         | 
| 400 | 
            +
                    return yolo_record
         | 
| 401 | 
            +
             | 
| 402 | 
            +
             | 
| 403 | 
            +
            @dataclass
            class CocoObject:
                """One labelled box stored as top-left corner plus size.

                Attributes:
                    label: Category name of the box.
                    x_min: Horizontal coordinate of the box origin.
                    y_min: Vertical coordinate of the box origin.
                    width: Horizontal extent of the box.
                    height: Vertical extent of the box.
                """

                label: str
                x_min: float
                y_min: float
                width: float
                height: float
         | 
| 410 | 
            +
             | 
| 411 | 
            +
             | 
| 412 | 
            +
            @dataclass
            class CocoRecord:
                """Annotations of a single image in COCO form.

                Attributes:
                    filename: File name of the image.
                    width: Image width.
                    height: Image height.
                    objects: ``CocoObject`` entries of the image; every record gets
                        its own fresh list by default.
                """

                filename: str
                width: int
                height: int
                objects: List[CocoObject] = field(default_factory=list)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
             | 
| 420 | 
            +
            class CocoHandler:
                """Read COCO annotation files and convert COCO records to/from IR."""

                @staticmethod
                def load(filename, **kwargs) -> List[CocoRecord]:
                    """Load every image record of a COCO JSON file.

                    Args:
                        filename: JSON file read with ``khandy.load_json``; it must
                            hold ``images``, ``annotations`` and ``categories``.
                        **kwargs: Accepted for signature parity; not used here.

                    Returns:
                        List[CocoRecord]: One record per image, in file order, each
                        holding the annotations that reference its image id.
                    """
                    content = khandy.load_json(filename)

                    image_items = content['images']
                    annotation_items = content['annotations']
                    category_items = content['categories']

                    # category id -> category name
                    name_by_id = {}
                    for category in category_items:
                        name_by_id[category['id']] = category['name']

                    # image id -> record, keeping the order of the file
                    record_by_image = OrderedDict()
                    for image in image_items:
                        record_by_image[image['id']] = CocoRecord(
                            filename=image['file_name'],
                            width=image['width'],
                            height=image['height'],
                            objects=[])

                    for annotation in annotation_items:
                        category_name = name_by_id[annotation['category_id']]
                        bbox = annotation['bbox']
                        box = CocoObject(
                            label=category_name,
                            x_min=bbox[0],
                            y_min=bbox[1],
                            width=bbox[2],
                            height=bbox[3])
                        record_by_image[annotation['image_id']].objects.append(box)
                    return list(record_by_image.values())

                @staticmethod
                def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
                    """Convert origin/size boxes into corner boxes.

                    Args:
                        coco_record: Record to convert.

                    Returns:
                        DetectIrRecord: Same filename and size; each box keeps its
                        origin and gets ``x_max``/``y_max`` as origin plus extent.
                    """
                    converted = DetectIrRecord(
                        filename=coco_record.filename,
                        width=coco_record.width,
                        height=coco_record.height,
                    )
                    for box in coco_record.objects:
                        converted.objects.append(DetectIrObject(
                            label=box.label,
                            x_min=box.x_min,
                            y_min=box.y_min,
                            x_max=box.x_min + box.width,
                            y_max=box.y_min + box.height
                        ))
                    return converted

                @staticmethod
                def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
                    """Convert corner boxes into origin/size boxes.

                    Args:
                        ir_record: Record to convert.

                    Returns:
                        CocoRecord: Same filename and size; each box keeps its
                        minimum corner and stores the corner differences as size.
                    """
                    converted = CocoRecord(
                        filename=ir_record.filename,
                        width=ir_record.width,
                        height=ir_record.height
                    )
                    for box in ir_record.objects:
                        converted.objects.append(CocoObject(
                            label=box.label,
                            x_min=box.x_min,
                            y_min=box.y_min,
                            width=box.x_max - box.x_min,
                            height=box.y_max - box.y_min
                        ))
                    return converted
         | 
| 487 | 
            +
             | 
| 488 | 
            +
             | 
| 489 | 
            +
            def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
                """Load a detection label file and convert it to the IR form.

                Args:
                    filename: Label file to read.
                    fmt: One of ``'labelme'``, ``'yolo'``, ``'voc'``, ``'pascal'``,
                        ``'pascal_voc'`` or ``'coco'``.
                    **kwargs: Forwarded unchanged to the selected handler's ``load``.

                Returns:
                    The converted record. For ``'coco'`` a list with one IR record
                    per loaded COCO record is returned instead of a single record.

                Raises:
                    ValueError: If ``fmt`` is not a supported format name.
                """
                if fmt == 'labelme':
                    loaded = LabelmeHandler.load(filename, **kwargs)
                    return LabelmeHandler.to_ir(loaded)
                if fmt == 'yolo':
                    loaded = YoloHandler.load(filename, **kwargs)
                    return YoloHandler.to_ir(loaded)
                if fmt in ('voc', 'pascal', 'pascal_voc'):
                    loaded = PascalVocHandler.load(filename, **kwargs)
                    return PascalVocHandler.to_ir(loaded)
                if fmt == 'coco':
                    loaded_records = CocoHandler.load(filename, **kwargs)
                    return [CocoHandler.to_ir(item) for item in loaded_records]
                raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
         | 
| 506 | 
            +
             | 
| 507 | 
            +
             | 
| 508 | 
            +
            def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
                """Convert an IR record to ``out_fmt`` and write it to ``filename``.

                The parent directory of ``filename`` is created first, before the
                format name is checked.

                Args:
                    filename: Destination path.
                    ir_record: Record to convert and save.
                    out_fmt: ``'labelme'``, ``'yolo'``, ``'voc'``, ``'pascal'`` or
                        ``'pascal_voc'``.

                Raises:
                    ValueError: For ``'coco'`` (saving is not supported) and for any
                        unknown format name.
                """
                parent_dir = os.path.dirname(os.path.abspath(filename))
                os.makedirs(parent_dir, exist_ok=True)

                if out_fmt == 'labelme':
                    handler = LabelmeHandler
                elif out_fmt == 'yolo':
                    handler = YoloHandler
                elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
                    handler = PascalVocHandler
                elif out_fmt == 'coco':
                    raise ValueError("Unsupported for `coco` now!")
                else:
                    raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")

                converted = handler.from_ir(ir_record)
                handler.save(filename, converted)
         | 
| 523 | 
            +
             | 
| 524 | 
            +
             | 
| 525 | 
            +
            def _get_format(record):
                """Return the format names that describe ``record``'s type.

                Args:
                    record: Any object; usually one of the label record types.

                Returns:
                    tuple: All accepted names for the record's format, or an empty
                    tuple when the type is not a known record type.
                """
                names_by_type = (
                    (LabelmeRecord, ('labelme',)),
                    (YoloRecord, ('yolo',)),
                    (PascalVocRecord, ('voc', 'pascal', 'pascal_voc')),
                    (CocoRecord, ('coco',)),
                    (DetectIrRecord, ('ir', 'detect_ir')),
                )
                # First matching type wins, in the order listed above.
                for record_type, names in names_by_type:
                    if isinstance(record, record_type):
                        return names
                return ()
         | 
| 538 | 
            +
             | 
| 539 | 
            +
             | 
| 540 | 
            +
            def convert_detect(record, out_fmt):
                """Convert a label record to another format via the IR form.

                Args:
                    record: A Labelme, YOLO, Pascal VOC, COCO or IR record.
                    out_fmt: Target format name: ``'labelme'``, ``'yolo'``, ``'voc'``,
                        ``'coco'``, ``'pascal'``, ``'pascal_voc'``, ``'ir'`` or
                        ``'detect_ir'``.

                Returns:
                    ``record`` itself when it already has the requested format,
                    otherwise a newly converted record.

                Raises:
                    ValueError: If ``out_fmt`` is not a known format name.
                    TypeError: If ``record`` is not a supported record type.
                """
                voc_names = ('voc', 'pascal', 'pascal_voc')
                ir_names = ('ir', 'detect_ir')
                if out_fmt not in ('labelme', 'yolo', 'coco') + voc_names + ir_names:
                    raise ValueError(
                        "Unsupported label format conversions for given out_fmt")
                if out_fmt in _get_format(record):
                    return record

                # Step 1: bring the input into the intermediate representation.
                sources = (
                    (LabelmeRecord, LabelmeHandler),
                    (YoloRecord, YoloHandler),
                    (PascalVocRecord, PascalVocHandler),
                    (CocoRecord, CocoHandler),
                )
                for record_type, handler in sources:
                    if isinstance(record, record_type):
                        ir_record = handler.to_ir(record)
                        break
                else:
                    if not isinstance(record, DetectIrRecord):
                        raise TypeError('Unsupported type for record')
                    ir_record = record

                # Step 2: turn the intermediate representation into the target.
                if out_fmt in ir_names:
                    return ir_record
                if out_fmt == 'labelme':
                    return LabelmeHandler.from_ir(ir_record)
                if out_fmt == 'yolo':
                    return YoloHandler.from_ir(ir_record)
                if out_fmt in voc_names:
                    return PascalVocHandler.from_ir(ir_record)
                # Only 'coco' is left after the validation at the top.
                return CocoHandler.from_ir(ir_record)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
             | 
| 575 | 
            +
            def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
                """Return a deep copy of ``record`` with labels renamed via ``label_map``.

                Args:
                    record: Record to copy; the input itself is left untouched.
                    label_map: Mapping from old label to new label.
                    ignore: When true, objects whose label is missing from
                        ``label_map`` are dropped; when false they are kept as is.

                Returns:
                    The copied record with its ``objects`` list rebuilt.
                """
                result = copy.deepcopy(record)
                kept = []
                for item in result.objects:
                    mapped = item.label in label_map
                    if mapped:
                        item.label = label_map[item.label]
                    if mapped or not ignore:
                        kept.append(item)
                result.objects = kept
                return result
         | 
| 589 | 
            +
             | 
| 590 | 
            +
             | 
| 591 | 
            +
            def load_coco_class_names(filename):
                """Return the category names of a COCO JSON file, in file order.

                Args:
                    filename: JSON file read with ``khandy.load_json``; it must hold
                        a ``categories`` list whose items have a ``name`` key.

                Returns:
                    list: The ``name`` value of every category item.
                """
                names = []
                for category in khandy.load_json(filename)['categories']:
                    names.append(category['name'])
                return names
         | 
    	
        khandy/list_utils.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            import itertools
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def to_list(obj):
                """Turn ``obj`` into a list, wrapping scalars and strings.

                Args:
                    obj: Any object.

                Returns:
                    ``None`` when ``obj`` is ``None``; ``list(obj)`` for non-string
                    objects that define ``__iter__``; otherwise ``[obj]``. When
                    iterating ``obj`` raises an ordinary exception, ``[obj]`` is
                    returned as a fallback.
                """
                if obj is None:
                    return None
                if isinstance(obj, str) or not hasattr(obj, '__iter__'):
                    return [obj]
                try:
                    return list(obj)
                except Exception:
                    # Best-effort fallback for objects whose iteration fails.
                    # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
                    # and SystemExit propagate instead of being swallowed.
                    return [obj]
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def convert_lists_to_record(*list_objs, delimiter=None):
                """Join equally long lists element-wise into delimited strings.

                Args:
                    *list_objs: One or more lists/tuples of the same length.
                    delimiter: Separator placed between the items of one record;
                        any falsy value (``None``, ``''``) falls back to ``','``.

                Returns:
                    list: One string per position, made of ``str()`` of the items
                    found at that position in every input list.

                Raises:
                    AssertionError: If no list is given, an argument is not a
                        list/tuple, or the lengths differ.
                """
                assert len(list_objs) >= 1, 'list_objs length must >= 1.'
                separator = delimiter if delimiter else ','

                first, *others = list_objs
                assert isinstance(first, (tuple, list))
                expected = len(first)
                for column in others:
                    assert isinstance(column, (tuple, list))
                    assert len(column) == expected, '{} != {}'.format(len(column), expected)

                return [separator.join(map(str, row)) for row in zip(*list_objs)]
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def shuffle_table(*table):
                """Shuffle several equally long columns with one shared permutation.

                Args:
                    *table: Columns (lists of equal length); items at the same
                        position form one row and stay together.

                Returns:
                    list: One new list per input column, all reordered the same
                    way. Empty columns give one empty list per column, so the
                    result can always be unpacked into as many names as columns.

                Notes:
                    table can be seen as list of list which have equal items.
                """
                rows = list(zip(*table))
                if not rows:
                    # zip(*[]) would yield nothing and drop the column count.
                    return [[] for _ in table]
                random.shuffle(rows)
                return [list(column) for column in zip(*rows)]
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
            def transpose_table(table):
                """Swap rows and columns of a table.

                Args:
                    table: Sequence of rows; the first row fixes the column count.

                Returns:
                    list: One list per column of ``table``. An empty table gives
                    an empty list.

                Raises:
                    IndexError: If a later row is shorter than the first row.

                Notes:
                    table can be seen as list of list which have equal items.
                """
                row_count = len(table)
                if row_count == 0:
                    # No first row to measure; nothing to transpose.
                    return []
                column_count = len(table[0])
                return [[table[i][j] for i in range(row_count)]
                        for j in range(column_count)]
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def concat_list(in_list):
                """Flatten one nesting level of a list of lists.

                Args:
                    in_list (list): The list of list to be merged.

                Returns:
                    list: The items of every inner list, in order, as one flat list.

                References:
                    mmcv.concat_list
                """
                flattened = itertools.chain.from_iterable(in_list)
                return list(flattened)
         | 
| 68 | 
            +
                
         | 
    	
        khandy/misc.py
    ADDED
    
    | @@ -0,0 +1,245 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import socket
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            import warnings
         | 
| 6 | 
            +
            from enum import Enum
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import requests
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def all_of(iterable, pred):
                """Check that the predicate holds for every element.

                Args:
                    iterable (Iterable): Elements to test.
                    pred (callable): Called with one element; its result is
                        interpreted by truthiness.

                Returns:
                    bool: False as soon as one element fails the predicate (later
                    elements are not tested), True otherwise, including for an
                    empty iterable.

                References:
                    https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
                """
                for element in iterable:
                    if not pred(element):
                        return False
                return True
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def any_of(iterable, pred):
                """Check that the predicate holds for at least one element.

                Args:
                    iterable (Iterable): Elements to test.
                    pred (callable): Called with one element; its result is
                        interpreted by truthiness.

                Returns:
                    bool: True as soon as one element satisfies the predicate
                    (later elements are not tested), False otherwise, including
                    for an empty iterable.

                References:
                    https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
                """
                for element in iterable:
                    if pred(element):
                        return True
                return False
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def none_of(iterable, pred):
                """Check that the predicate holds for no element.

                Args:
                    iterable (Iterable): Elements to test.
                    pred (callable): Called with one element; its result is
                        interpreted by truthiness.

                Returns:
                    bool: False as soon as one element satisfies the predicate
                    (later elements are not tested), True otherwise, including
                    for an empty iterable.

                References:
                    https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
                """
                for element in iterable:
                    if pred(element):
                        return False
                return True
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def print_with_no(obj):
                """Print ``obj`` one element per line, each prefixed with its 1-based number.

                Objects exposing ``__len__`` are printed as ``[k/total] item``; other
                iterables as ``[k] item``; anything else as the single line ``[1] obj``.

                Args:
                    obj: The object (or collection of objects) to print.
                """
                if hasattr(obj, '__len__'):
                    for number, item in enumerate(obj, start=1):
                        print(f'[{number}/{len(obj)}] {item}')
                    return

                if hasattr(obj, '__iter__'):
                    for number, item in enumerate(obj, start=1):
                        print(f'[{number}] {item}')
                    return

                print(f'[1] {obj}')
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                  
         | 
| 70 | 
            +
            def get_file_line_count(filename, encoding='utf-8'):
                """Count the newline characters in a text file.

                The file is read in chunks of 8 * 1024 * 1024 characters so that it is
                never held in memory as a whole. Only ``'\\n'`` characters are counted,
                so a final line without a trailing newline does not add to the total.

                Args:
                    filename: Path of the file to read.
                    encoding (str): Text encoding used to open the file.

                Returns:
                    int: Number of newline characters found.
                """
                chunk_chars = 8 * 1024 * 1024
                with open(filename, 'r', encoding=encoding) as stream:
                    # iter() with a sentinel stops at the empty string read() returns at EOF.
                    chunks = iter(lambda: stream.read(chunk_chars), '')
                    return sum(chunk.count('\n') for chunk in chunks)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def get_host_ip():
                """Return the local IPv4 address used for outbound traffic.

                A UDP socket is connected to ``8.8.8.8:80`` and the local address the
                operating system bound to that socket is returned.

                Returns:
                    str: The local IP address of the socket.

                Raises:
                    OSError: If the socket cannot be created or connected.
                """
                # The context manager closes the socket on every path. The previous
                # try/finally referenced the socket variable even when creating the
                # socket had failed, which hid the real error behind an
                # UnboundLocalError.
                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                    sock.connect(('8.8.8.8', 80))
                    return sock.getsockname()[0]
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            def set_logger(filename, level=logging.INFO, logger_name=None, formatter=None, with_print=True):
                """Configure a logger to write to a file and, optionally, to the console.

                Every ``FileHandler`` and every plain ``StreamHandler`` already attached
                to the logger is removed and closed first, so calling this function
                repeatedly neither duplicates output nor leaks open log files.

                Args:
                    filename: Path of the log file (opened with UTF-8 encoding).
                    level: Logging level set on the logger.
                    logger_name: Name passed to ``logging.getLogger``; None selects
                        the root logger.
                    formatter (logging.Formatter): Formatter for the new handlers;
                        defaults to a bare ``'%(message)s'`` formatter.
                    with_print (bool): Also attach a console ``StreamHandler``.

                Returns:
                    logging.Logger: The configured logger.
                """
                logger = logging.getLogger(logger_name)
                logger.setLevel(level)

                if formatter is None:
                    formatter = logging.Formatter('%(message)s')

                # Iterate over a copy: removeHandler() mutates logger.handlers.
                for handler in logger.handlers[:]:
                    # FileHandler is a subclass of StreamHandler, so the console
                    # check compares the exact type; other StreamHandler subclasses
                    # are deliberately left attached.
                    is_file = isinstance(handler, logging.FileHandler)
                    is_plain_stream = type(handler) is logging.StreamHandler
                    if is_file or is_plain_stream:
                        logger.removeHandler(handler)
                        # Release the underlying file; removeHandler() alone leaves
                        # the old log file open.
                        handler.close()

                file_handler = logging.FileHandler(filename, encoding='utf-8')
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
                if with_print:
                    console_handler = logging.StreamHandler()
                    console_handler.setFormatter(formatter)
                    logger.addHandler(console_handler)
                return logger
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def print_arguments(args):
                """Print every attribute of an argparse namespace as ``key: value``.

                Lines are emitted in ascending order of attribute name.

                Args:
                    args (argparse.Namespace): Parsed command-line arguments.
                """
                assert isinstance(args, argparse.Namespace)
                namespace = vars(args)
                for name in sorted(namespace):
                    print(f'{name}: {namespace[name]}')
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            def save_arguments(filename, args, sort=True):
                """Write the attributes of an argparse namespace to a JSON file.

                Args:
                    filename: Path of the JSON file to create or overwrite.
                    args (argparse.Namespace): Parsed command-line arguments.
                    sort (bool): Whether keys are written in sorted order.
                """
                assert isinstance(args, argparse.Namespace)
                with open(filename, 'w') as json_file:
                    json.dump(vars(args), json_file, indent=4, sort_keys=sort)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            class DownloadStatusCode(Enum):
                """Reasons a download can be rejected.

                Each member's value is a ``(code, message)`` tuple; the two parts are
                exposed through the ``code`` and ``message`` properties.
                """
                FILE_SIZE_TOO_LARGE = (-100, 'the size of file from url is too large')
                FILE_SIZE_TOO_SMALL = (-101, 'the size of file from url is too small')
                FILE_SIZE_IS_ZERO = (-102, 'the size of file from url is zero')
                URL_IS_NOT_IMAGE = (-103, 'URL is not an image')

                @property
                def code(self):
                    """Numeric status code (first item of the member value)."""
                    number, _ = self.value
                    return number

                @property
                def message(self):
                    """Human-readable description (second item of the member value)."""
                    _, text = self.value
                    return text
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            class DownloadError(Exception):
                """Raised when a download is rejected.

                Attributes:
                    name: Name of the status code member.
                    code: Numeric code of the status.
                    message: Status message, followed by ``': <extra_str>'`` when
                        extra text was supplied.
                """

                def __init__(self, status_code: 'DownloadStatusCode', extra_str: str = None):
                    """Build the error from a status code.

                    Args:
                        status_code: Object providing ``name``, ``code`` and
                            ``message`` (a ``DownloadStatusCode`` member).
                        extra_str: Optional detail appended to the status message.
                    """
                    self.name = status_code.name
                    self.code = status_code.code
                    if extra_str is None:
                        self.message = status_code.message
                    else:
                        self.message = f'{status_code.message}: {extra_str}'
                    # Forward the constructor arguments so that ``args`` is populated
                    # and the exception can be pickled/copied (e.g. when it crosses a
                    # process boundary); with empty ``args`` rebuilding it would call
                    # the constructor without its required status code.
                    super().__init__(status_code, extra_str)

                def __repr__(self):
                    return f'[{self.__class__.__name__} {self.code}] {self.message}'

                __str__ = __repr__
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                
         | 
| 165 | 
            +
            def download_image(image_url, min_filesize=0, max_filesize=100*1024*1024,
                               params=None, **kwargs) -> bytes:
                """Download an image over HTTP and return its raw bytes.

                The response is rejected early when its Content-Type is neither
                ``image/*`` nor ``application/octet-stream``, or when the declared
                Content-Length is outside the allowed range. The body is then read in
                10 KiB chunks and the size limits are enforced again on the bytes
                actually received.

                Args:
                    image_url: URL of the image.
                    min_filesize (int): Smallest accepted size in bytes.
                    max_filesize (int): Largest accepted size in bytes.
                    params: Query parameters forwarded to ``requests.get``.
                    **kwargs: Extra keyword arguments forwarded to ``requests.get``;
                        ``stream`` defaults to True.

                Returns:
                    bytes: The downloaded content.

                Raises:
                    DownloadError: If the content type is not accepted or the size is
                        outside ``[min_filesize, max_filesize]``.
                    Exception: Whatever ``response.raise_for_status()`` raises for an
                        HTTP error status.

                References:
                    https://httpwg.org/specs/rfc9110.html#field.content-length
                    https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
                """
                stream = kwargs.pop('stream', True)

                with requests.get(image_url, stream=stream, params=params, **kwargs) as response:
                    response.raise_for_status()

                    content_type = response.headers.get('content-type')
                    if content_type is None:
                        warnings.warn('No Content-Type!')
                    else:
                        # Media types are case-insensitive, so normalise before matching.
                        media_type = content_type.strip().lower()
                        if not media_type.startswith(('image/', 'application/octet-stream')):
                            raise DownloadError(DownloadStatusCode.URL_IS_NOT_IMAGE)

                    # when Transfer-Encoding == chunked, Content-Length does not exist.
                    content_length = response.headers.get('content-length')
                    if content_length is None:
                        warnings.warn('No Content-Length!')
                    else:
                        try:
                            declared_size = int(content_length)
                        except ValueError:
                            # An unparsable header must not abort the download; the
                            # streamed size check below still enforces the limits.
                            warnings.warn(f'Invalid Content-Length: {content_length!r}')
                            declared_size = None
                        if declared_size is not None:
                            if declared_size > max_filesize:
                                raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
                            if declared_size < min_filesize:
                                raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)

                    filesize = 0
                    chunks = []
                    for chunk in response.iter_content(chunk_size=10*1024):
                        filesize += len(chunk)
                        # Stop as soon as the limit is exceeded instead of buffering more.
                        if filesize > max_filesize:
                            raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
                        chunks.append(chunk)
                    if filesize < min_filesize:
                        raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
                    image_bytes = b''.join(chunks)

                return image_bytes
         | 
| 207 | 
            +
                
         | 
| 208 | 
            +
             | 
| 209 | 
            +
            def download_file(url, min_filesize=0, max_filesize=100*1024*1024,
                              params=None, **kwargs) -> bytes:
                """Download a file over HTTP and return its raw bytes.

                The response is rejected early when the declared Content-Length is
                outside the allowed range. The body is then read in 10 KiB chunks and
                the size limits are enforced again on the bytes actually received.

                Args:
                    url: URL of the file.
                    min_filesize (int): Smallest accepted size in bytes.
                    max_filesize (int): Largest accepted size in bytes.
                    params: Query parameters forwarded to ``requests.get``.
                    **kwargs: Extra keyword arguments forwarded to ``requests.get``;
                        ``stream`` defaults to True.

                Returns:
                    bytes: The downloaded content.

                Raises:
                    DownloadError: If the size is outside
                        ``[min_filesize, max_filesize]``.
                    Exception: Whatever ``response.raise_for_status()`` raises for an
                        HTTP error status.

                References:
                    https://httpwg.org/specs/rfc9110.html#field.content-length
                    https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
                """
                stream = kwargs.pop('stream', True)

                with requests.get(url, stream=stream, params=params, **kwargs) as response:
                    response.raise_for_status()

                    # when Transfer-Encoding == chunked, Content-Length does not exist.
                    content_length = response.headers.get('content-length')
                    if content_length is None:
                        warnings.warn('No Content-Length!')
                    else:
                        try:
                            declared_size = int(content_length)
                        except ValueError:
                            # An unparsable header must not abort the download; the
                            # streamed size check below still enforces the limits.
                            warnings.warn(f'Invalid Content-Length: {content_length!r}')
                            declared_size = None
                        if declared_size is not None:
                            if declared_size > max_filesize:
                                raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
                            if declared_size < min_filesize:
                                raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)

                    filesize = 0
                    chunks = []
                    for chunk in response.iter_content(chunk_size=10*1024):
                        filesize += len(chunk)
                        # Stop as soon as the limit is exceeded instead of buffering more.
                        if filesize > max_filesize:
                            raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
                        chunks.append(chunk)
                    if filesize < min_filesize:
                        raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
                    file_bytes = b''.join(chunks)

                return file_bytes
         | 
| 244 | 
            +
                
         | 
| 245 | 
            +
             | 
    	
        khandy/numpy_utils.py
    ADDED
    
    | @@ -0,0 +1,173 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def sigmoid(x):
                """Compute the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

                The value is evaluated as ``exp(-logaddexp(0, -x))``, which is the same
                function but does not overflow for large negative inputs.

                Args:
                    x: Scalar or numpy array.

                Returns:
                    The sigmoid of ``x``, with the same shape as ``x``.
                """
                # log(1 + exp(-x)) == logaddexp(0, -x), computed without forming exp(-x).
                return np.exp(-np.logaddexp(0, -x))
         | 
| 6 | 
            +
                
         | 
| 7 | 
            +
                
         | 
| 8 | 
            +
            def softmax(x, axis=-1, copy=True):
                """Compute the softmax of ``x`` along an axis.

                The maximum along ``axis`` is subtracted before exponentiating to keep
                the computation numerically stable.

                Args:
                    x: Input data. With ``copy=True`` any array_like is accepted and
                        non-floating input is promoted to float64; with ``copy=False``
                        it must be a floating-point ndarray, which is overwritten.
                    axis: Axis along which the softmax is computed.
                    copy: Copy x or not.

                Returns:
                    Array of the same shape as ``x`` whose entries along ``axis`` sum
                    to 1. It is ``x`` itself when ``copy=False``.

                References:
                    `from sklearn.utils.extmath import softmax`
                """
                if copy:
                    x = np.copy(x)
                    # The in-place exp/divide below cannot write floats into an
                    # integer or boolean buffer, so promote such input first.
                    if not np.issubdtype(x.dtype, np.inexact):
                        x = x.astype(np.float64)
                x -= np.max(x, axis=axis, keepdims=True)
                np.exp(x, out=x)
                x /= np.sum(x, axis=axis, keepdims=True)
                return x
         | 
| 24 | 
            +
                
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
            def log_sum_exp(x, axis=-1, keepdims=False):
                """Compute ``log(sum(exp(x)))`` along an axis in a numerically stable way.

                The maximum along ``axis`` is subtracted before exponentiating and added
                back to the result. The input is not modified.

                Args:
                    x: array_like input data.
                    axis: Axis or axes along which to reduce.
                    keepdims (bool): Keep the reduced axes with length one.

                Returns:
                    The log-sum-exp of ``x`` along ``axis``.

                References:
                    numpy.logaddexp
                    numpy.logaddexp2
                    scipy.special.logsumexp
                """
                x = np.asarray(x)
                max_val = np.max(x, axis=axis, keepdims=True)
                # Work on a temporary so the caller's array is left untouched.
                sum_exp = np.sum(np.exp(x - max_val), axis=axis, keepdims=keepdims)
                # np.log without an ``out`` argument also handles the NumPy scalar
                # that np.sum returns when the whole array is reduced.
                lse = np.log(sum_exp)
                if not keepdims:
                    max_val = np.squeeze(max_val, axis=axis)
                return max_val + lse
         | 
| 41 | 
            +
                
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
            def l2_normalize(x, axis=None, epsilon=1e-12, copy=True):
                """L2 normalize an array along an axis.

                Args:
                    x : array_like
                        Input data. With ``copy=True`` non-floating input (e.g. a list
                        of ints) is promoted to float64; with ``copy=False`` it must be
                        a floating-point ndarray, which is overwritten.
                    axis : None or int or tuple of ints, optional
                        Axis or axes along which to operate.
                    epsilon: float, optional
                        Lower bound applied to the norm to avoid division by zero.
                    copy : bool, optional
                        Copy x or not.

                Returns:
                    The normalized array; it is ``x`` itself when ``copy=False``.
                """
                if copy:
                    x = np.copy(x)
                    # The in-place division cannot store floats in an integer or
                    # boolean buffer, so promote such input first.
                    if not np.issubdtype(x.dtype, np.inexact):
                        x = x.astype(np.float64)
                norm = np.linalg.norm(x, axis=axis, keepdims=True)
                x /= np.maximum(norm, epsilon)
                return x
         | 
| 60 | 
            +
                
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
            def minmax_normalize(x, axis=None, epsilon=1e-12, copy=True):
                """Min-max normalize an array along a given axis.

                Each value is mapped to ``(x - min) / max(max - min, epsilon)``.

                Args:
                    x : array_like
                        Input data. With ``copy=True`` non-floating input (e.g. a list
                        of ints) is promoted to float64; with ``copy=False`` it must be
                        a floating-point ndarray, which is overwritten.
                    axis : None or int or tuple of ints, optional
                        Axis or axes along which to operate.
                    epsilon: float, optional
                        Lower bound applied to the value range to avoid division by
                        zero.
                    copy : bool, optional
                        Copy x or not.

                Returns:
                    The normalized array; it is ``x`` itself when ``copy=False``.
                """
                if copy:
                    x = np.copy(x)
                    # The in-place division cannot store floats in an integer or
                    # boolean buffer, so promote such input first.
                    if not np.issubdtype(x.dtype, np.inexact):
                        x = x.astype(np.float64)

                minval = np.min(x, axis=axis, keepdims=True)
                maxval = np.max(x, axis=axis, keepdims=True)
                value_range = np.maximum(maxval - minval, epsilon)

                x -= minval
                x /= value_range
                return x
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def zscore_normalize(x, mean=None, std=None, axis=None, epsilon=1e-12, copy=True):
                """Z-score normalize an array along a given axis.

                Each value is mapped to ``(x - mean) / max(std, epsilon)``.

                Args:
                    x : array_like
                        Input data. With ``copy=True`` non-floating input (e.g. a list
                        of ints) is promoted to float64; with ``copy=False`` it must be
                        a floating-point ndarray, which is overwritten.
                    mean:  array_like of floats, optional
                        Mean for z-score; computed from ``x`` along ``axis`` when None.
                    std: array_like of floats, optional
                        Std for z-score; computed from ``x`` along ``axis`` when None.
                    axis : None or int or tuple of ints, optional
                        Axis or axes along which to operate.
                    epsilon: float, optional
                        Lower bound applied to std to avoid division by zero.
                    copy : bool, optional
                        Copy x or not.

                Returns:
                    The normalized array; it is ``x`` itself when ``copy=False``.
                """
                if copy:
                    x = np.copy(x)
                    # Promote before mean/std are cast to x.dtype below; an integer
                    # dtype would truncate them and reject the in-place division.
                    if not np.issubdtype(x.dtype, np.inexact):
                        x = x.astype(np.float64)
                if mean is None:
                    mean = np.mean(x, axis=axis, keepdims=True)
                if std is None:
                    std = np.std(x, axis=axis, keepdims=True)
                mean = np.asarray(mean, dtype=x.dtype)
                std = np.maximum(np.asarray(std, dtype=x.dtype), epsilon)

                x -= mean
                x /= std
                return x
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def get_order_of_magnitude(number):
                """Return ``floor(log10(|number|))`` as int32.

                Entries equal to zero are replaced by 1 beforehand, so their order of
                magnitude is reported as 0.

                Args:
                    number: Scalar or numpy array.

                Returns:
                    numpy int32 array with the order of magnitude of each entry.
                """
                nonzero = np.where(number == 0, 1, number)
                exponent = np.log10(np.abs(nonzero))
                return np.floor(exponent).astype(np.int32)
         | 
| 124 | 
            +
                
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
            def top_k(x, k, axis=-1, largest=True, sorted=True):
                """Finds values and indices of the k largest/smallest
                elements along a given axis.

                Args:
                    x: array_like
                        1-D or higher with given axis at least k.
                    k: int
                        Number of top elements to look for along the given axis.
                    axis: int or None
                        The axis to select along; None operates on the flattened
                        input.
                    largest: bool
                        Controls whether to return largest or smallest elements
                    sorted: bool
                        If true the resulting k elements will be sorted by the values
                        (descending when ``largest`` is true, ascending otherwise).

                Returns:
                    topk_values:
                        The k largest/smallest elements along the given axis.
                    topk_indices:
                        The indices of the k largest/smallest elements along the given axis.
                """
                # Convert first so that lists and other array_like inputs provide
                # .size / .shape for the bounds check below.
                x = np.asanyarray(x)
                if axis is None:
                    axis_size = x.size
                else:
                    axis_size = x.shape[axis]
                assert 1 <= k <= axis_size

                if largest:
                    # Partition so that the k largest end up in the last k slots,
                    # then read those slots from the end.
                    index_array = np.argpartition(x, axis_size-k, axis=axis)
                    topk_indices = np.take(index_array, -np.arange(k)-1, axis=axis)
                else:
                    # Partition so that the k smallest end up in the first k slots.
                    index_array = np.argpartition(x, k-1, axis=axis)
                    topk_indices = np.take(index_array, np.arange(k), axis=axis)
                topk_values = np.take_along_axis(x, topk_indices, axis=axis)
                if sorted:
                    # argpartition gives no order inside the selected group.
                    sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
                    if largest:
                        sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
                    sorted_topk_values = np.take_along_axis(
                        topk_values, sorted_indices_in_topk, axis=axis)
                    sorted_topk_indices = np.take_along_axis(
                        topk_indices, sorted_indices_in_topk, axis=axis)
                    return sorted_topk_values, sorted_topk_indices
                return topk_values, topk_indices
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                
         | 
    	
        khandy/points/__init__.py
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .pts_letterbox import *
         | 
| 2 | 
            +
            from .pts_transform_scale import *
         | 
    	
        khandy/points/pts_letterbox.py
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            __all__ = ['letterbox_2d_points', 'unletterbox_2d_points']
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def letterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
                """Map 2d points into letterboxed coordinates (scale, then pad).

                Even positions of the last axis are treated as x values and get
                ``x * scale + pad_left``; odd positions are treated as y values and get
                ``y * scale + pad_top``.

                Args:
                    points: Array supporting ``.copy()`` and ellipsis indexing, with
                        x and y interleaved along the last axis.
                    scale: Scale factor applied to both coordinates.
                    pad_left: Offset added to x after scaling.
                    pad_top: Offset added to y after scaling.
                    copy: Work on a copy of ``points`` instead of modifying it.

                Returns:
                    The transformed points; ``points`` itself when ``copy`` is false.
                """
                if copy:
                    points = points.copy()
                x_columns = (Ellipsis, slice(0, None, 2))
                y_columns = (Ellipsis, slice(1, None, 2))
                points[x_columns] = points[x_columns] * scale + pad_left
                points[y_columns] = points[y_columns] * scale + pad_top
                return points
         | 
| 10 | 
            +
                
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
            def unletterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
                """Map letterboxed 2d points back (remove padding, then undo the scale).

                Even positions of the last axis are treated as x values and get
                ``(x - pad_left) / scale``; odd positions are treated as y values and
                get ``(y - pad_top) / scale``.

                Args:
                    points: Array supporting ``.copy()`` and ellipsis indexing, with
                        x and y interleaved along the last axis.
                    scale: Scale factor that was applied to both coordinates.
                    pad_left: Offset that was added to x.
                    pad_top: Offset that was added to y.
                    copy: Work on a copy of ``points`` instead of modifying it.

                Returns:
                    The transformed points; ``points`` itself when ``copy`` is false.
                """
                if copy:
                    points = points.copy()
                x_columns = (Ellipsis, slice(0, None, 2))
                y_columns = (Ellipsis, slice(1, None, 2))
                points[x_columns] = (points[x_columns] - pad_left) / scale
                points[y_columns] = (points[y_columns] - pad_top) / scale
                return points
         | 
| 19 | 
            +
                
         | 
    	
        khandy/points/pts_transform_scale.py
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            __all__ = ['scale_2d_points']
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def scale_2d_points(points, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
                """Scale 2d points about a center.

                Every x value becomes ``x * x_scale + (1 - x_scale) * x_center`` and
                every y value becomes ``y * y_scale + (1 - y_scale) * y_center``.

                Args:
                    points: (..., 2N), x and y interleaved along the last axis
                    x_scale: scale factor in x dimension
                    y_scale: scale factor in y dimension
                    x_center: scale center in x dimension
                    y_center: scale center in y dimension
                    copy: if True, always operate on a fresh float32 copy. If False,
                        a float32 ndarray input is modified in place; any other
                        input is converted (and therefore copied) first.

                Returns:
                    float32 ndarray holding the scaled points.
                """
                if copy:
                    points = np.array(points, dtype=np.float32)
                else:
                    # np.array(..., copy=False) raises on NumPy >= 2 whenever a copy
                    # cannot be avoided; asarray copies only if needed on all versions.
                    points = np.asarray(points, dtype=np.float32)
                x_scale = np.asarray(x_scale, np.float32)
                y_scale = np.asarray(y_scale, np.float32)
                x_center = np.asarray(x_center, np.float32)
                y_center = np.asarray(y_center, np.float32)

                # Out-of-place product so scale and center broadcast against each other.
                x_shift = (1 - x_scale) * x_center
                y_shift = (1 - y_scale) * y_center

                points[..., 0::2] *= x_scale
                points[..., 1::2] *= y_scale
                points[..., 0::2] += x_shift
                points[..., 1::2] += y_shift
                return points
         | 
| 32 | 
            +
                
         | 
| 33 | 
            +
                
         | 
    	
        khandy/split_utils.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numbers
         | 
| 2 | 
            +
            from collections.abc import Sequence
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def split_by_num(x, num_splits, strict=True):
                """Split a sequence into consecutive chunks of equal size.

                The chunk size is ``ceil(len(x) / num_splits)``. In non-strict mode the
                last chunk may be shorter and fewer than ``num_splits`` chunks may be
                returned. An empty ``x`` yields ``num_splits`` empty chunks.

                Args:
                    x: a Sequence or np.ndarray to split
                    num_splits: a positive integer indicating the number of splits
                    strict: if True, ``len(x)`` must be divisible by ``num_splits``

                Returns:
                    A list of slices of ``x``.

                References:
                    numpy.split and numpy.array_split
                """
                # NB: np.ndarray is not Sequence
                assert isinstance(x, (Sequence, np.ndarray))
                assert isinstance(num_splits, numbers.Integral)
                assert num_splits > 0, 'num_splits must be positive'

                total = len(x)
                if strict:
                    assert total % num_splits == 0
                if total == 0:
                    # A zero chunk size would make range() raise; mirror numpy.split,
                    # which returns one empty piece per requested split.
                    return [x[0:0] for _ in range(num_splits)]
                split_size = (total + num_splits - 1) // num_splits
                return [x[i: i + split_size] for i in range(0, total, split_size)]
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def split_by_size(x, sizes):
                """Split a sequence into consecutive chunks with the given lengths.

                Args:
                    x: a Sequence or np.ndarray to split
                    sizes: list or tuple of chunk lengths; must sum to ``len(x)``

                Returns:
                    A list with one slice of ``x`` per entry in ``sizes``.

                References:
                    tf.split
                    https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
                """
                # NB: np.ndarray is not Sequence
                assert isinstance(x, (Sequence, np.ndarray))
                assert isinstance(sizes, (list, tuple))

                assert sum(sizes) == len(x)
                chunks = []
                stop = 0
                for size in sizes:
                    # Each chunk starts where the previous one stopped.
                    start, stop = stop, stop + size
                    chunks.append(x[start:stop])
                return chunks
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def split_by_slice(x, slices):
                """Split a sequence at the given cut positions.

                Args:
                    x: a Sequence or np.ndarray to split
                    slices: list or tuple of indices at which to cut ``x``

                Returns:
                    A list of ``len(slices) + 1`` slices of ``x``.

                References:
                    SliceLayer in Caffe, and numpy.split
                """
                # NB: np.ndarray is not Sequence
                assert isinstance(x, (Sequence, np.ndarray))
                assert isinstance(slices, (list, tuple))

                # Bracket the cut points with both ends of the sequence.
                bounds = [0] + list(slices) + [len(x)]
                return [x[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def split_by_ratio(x, ratios):
                """Split a sequence into consecutive chunks proportional to ``ratios``.

                The ratios are normalized by their sum; each chunk boundary is the
                rounded product of ``len(x)`` and the running normalized total.

                Args:
                    x: a Sequence or np.ndarray to split
                    ratios: list or tuple of relative chunk sizes

                Returns:
                    A list with one slice of ``x`` per entry in ``ratios``.
                """
                # NB: np.ndarray is not Sequence
                assert isinstance(x, (Sequence, np.ndarray))
                assert isinstance(ratios, (list, tuple))

                # Hoisted: the total is loop-invariant.
                total = sum(ratios)
                length = len(x)
                cumulative = 0
                indices = [0]
                for ratio in ratios:
                    # Running sum instead of re-summing every prefix.
                    cumulative += ratio / total
                    indices.append(int(round(length * cumulative)))
                return [x[lo:hi] for lo, hi in zip(indices, indices[1:])]
         | 
    	
        khandy/text_utils.py
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def strip_content_in_paren(string):
                """Remove parenthesized segments, including the parentheses.

                Both ASCII parentheses ``(...)`` and full-width parentheses
                (U+FF08 ... U+FF09) are handled.

                Args:
                    string: input text

                Returns:
                    ``string`` with every parenthesized segment removed.

                Notes:
                    strip_content_in_paren cannot process nested paren correctly
                """
                # The second alternative must name the full-width delimiters explicitly;
                # a bare ``([^)]*)`` is a capturing group that matches (and deletes)
                # every run of characters other than ')'.
                return re.sub(r"\([^)]*\)|\uff08[^\uff09]*\uff09", "", string)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def is_chinese_char(uchar: str) -> bool:
                """Tell whether a single character is a CJK ideograph.

                Args:
                    uchar: input char in unicode (a one-character string)

                Returns:
                    True if the code point lies in one of the CJK ideograph blocks
                    listed below, False otherwise.

                References:
                    `is_chinese_char` in https://github.com/thunlp/OpenNRE/
                """
                cjk_blocks = (
                    (0x4E00, 0x9FFF),    # CJK Unified Ideographs
                    (0x3400, 0x4DBF),    # CJK Unified Ideographs Extension A
                    (0xF900, 0xFAFF),    # CJK Compatibility Ideographs
                    (0x20000, 0x2A6DF),  # CJK Unified Ideographs Extension B
                    (0x2A700, 0x2B73F),  # CJK Unified Ideographs Extension C
                    (0x2B740, 0x2B81F),  # CJK Unified Ideographs Extension D
                    (0x2B820, 0x2CEAF),  # CJK Unified Ideographs Extension E
                    (0x2F800, 0x2FA1F),  # CJK Compatibility Supplement
                )
                code = ord(uchar)
                return any(first <= code <= last for first, last in cjk_blocks)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
    	
        khandy/time_utils.py
    ADDED
    
    | @@ -0,0 +1,101 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import time
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            import numbers
         | 
| 4 | 
            +
            import datetime
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def _to_timestamp(val, multiplier=1, rounded=False):
                """Convert a time-like value to a timestamp scaled by ``multiplier``.

                Args:
                    val: one of
                        - None: the current time (``time.time()``)
                        - a real number: taken as a timestamp in seconds
                        - time.struct_time: converted with ``time.mktime``
                        - datetime.datetime: converted with ``.timestamp()``
                        - datetime.date: midnight of that date
                        - str: parsed with ``datetime.datetime.fromisoformat``
                    multiplier: factor applied to the timestamp in seconds
                    rounded: if True, round the result to an integer

                Returns:
                    The scaled timestamp (int when ``rounded`` is True).

                Raises:
                    TypeError: if ``val`` has an unsupported type, or is a str that
                        cannot be converted.
                """
                if val is None:
                    timestamp = time.time()
                elif isinstance(val, numbers.Real):
                    timestamp = float(val)
                elif isinstance(val, time.struct_time):
                    timestamp = time.mktime(val)
                elif isinstance(val, datetime.datetime):
                    # Checked before datetime.date because datetime subclasses date.
                    timestamp = val.timestamp()
                elif isinstance(val, datetime.date):
                    dt = datetime.datetime.combine(val, datetime.time())
                    timestamp = dt.timestamp()
                elif isinstance(val, str):
                    try:
                        # The full format looks like 'YYYY-MM-DD HH:MM:SS.mmmmmm'.
                        dt = datetime.datetime.fromisoformat(val)
                        timestamp = dt.timestamp()
                    except (ValueError, OverflowError, OSError) as err:
                        # Only conversion failures are translated; a bare ``except``
                        # would also swallow KeyboardInterrupt and SystemExit.
                        raise TypeError('when argument is str, it should conform to isoformat') from err
                else:
                    raise TypeError('unsupported type!')
                timestamp = timestamp * multiplier
                if rounded:
                    # The return value is an integer if ndigits is omitted or None.
                    timestamp = round(timestamp)
                return timestamp


            def get_timestamp(time_val=None, rounded=True):
                """timestamp in seconds.

                Args:
                    time_val: value accepted by ``_to_timestamp``; None means now
                    rounded: if True, round to an integer
                """
                return _to_timestamp(time_val, multiplier=1, rounded=rounded)


            def get_timestamp_ms(time_val=None, rounded=True):
                """timestamp in milliseconds.

                Args:
                    time_val: value accepted by ``_to_timestamp``; None means now
                    rounded: if True, round to an integer
                """
                return _to_timestamp(time_val, multiplier=1000, rounded=rounded)


            def get_timestamp_us(time_val=None, rounded=True):
                """timestamp in microseconds.

                Args:
                    time_val: value accepted by ``_to_timestamp``; None means now
                    rounded: if True, round to an integer
                """
                return _to_timestamp(time_val, multiplier=1000000, rounded=rounded)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def get_utc8now() -> datetime.datetime:
                """Return the current time as an aware datetime in UTC+8 (Beijing time)."""
                beijing_offset = datetime.timedelta(hours=8)
                return datetime.datetime.now(datetime.timezone(beijing_offset))
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            class ContextTimer(object):
                """Context manager that reports how long a block of code took.

                On entry it records ``time.time()`` and, unless ``quiet``, emits a
                "starts" message. On exit it emits the elapsed time and an "ends"
                message. Messages go to ``logging.info`` when ``use_log`` is True and
                to ``print`` otherwise. Exceptions raised inside the block are not
                suppressed.

                Args:
                    name: optional label prepended to every message
                    use_log: send messages to ``logging.info`` instead of ``print``
                    quiet: if True, emit no messages at all

                References:
                    WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
                """
                def __init__(self, name=None, use_log=False, quiet=False):
                    self.use_log = use_log
                    self.quiet = quiet
                    # Stored with a trailing ', ' so it can be prepended to messages as-is.
                    if name is None:
                        self.name = ''
                    else:
                        self.name = '{}, '.format(name.rstrip())

                def __enter__(self):
                    self.start_time = time.time()
                    if not self.quiet:
                        self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
                    return self

                def __exit__(self, exc_type, exc_val, exc_tb):
                    if not self.quiet:
                        self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
                        self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))

                @property
                def _now_time_str(self):
                    """Current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
                    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

                def _print_or_log(self, output_str):
                    """Emit ``output_str`` through logging or print, per ``use_log``."""
                    if self.use_log:
                        logging.info(output_str)
                    else:
                        print(output_str)

                def get_eplased_time(self):
                    """Seconds elapsed since the timer was entered.

                    The misspelled name is kept for backward compatibility; prefer
                    ``get_elapsed_time``.
                    """
                    return time.time() - self.start_time

                def get_elapsed_time(self):
                    """Seconds elapsed since the timer was entered."""
                    # Delegates to the legacy name so subclasses overriding it still apply.
                    return self.get_eplased_time()

                def enter(self):
                    """Manually trigger enter; returns the timer itself."""
                    return self.__enter__()

                def exit(self):
                    """Manually trigger exit (counterpart of ``enter``)."""
                    self.__exit__(None, None, None)
         | 
| 101 | 
            +
             | 
    	
        khandy/version.py
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            __version__ = '0.1.8'
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            __all__ = ['__version__']
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            opencv-python>=4.5
         | 
| 2 | 
            +
            numpy>=1.11.1
         | 
| 3 | 
            +
            lxml
         | 
| 4 | 
            +
            requests
         | 
| 5 | 
            +
            onnxruntime
         | 
| 6 | 
            +
            Pillow
         | 
| 7 | 
            +
            modelscope==1.15
         | 
