Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| cimport cython | |
| cimport libcpp | |
| cimport libcpp.pair | |
| cimport libcpp.queue | |
| cimport numpy as np | |
| from libcpp.pair cimport * | |
| from libcpp.queue cimport * | |
| @cython.boundscheck(False) | |
| @cython.wraparound(False) | |
| cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, | |
| np.ndarray[np.int32_t, ndim=2] label, | |
| int kernel_num, | |
| int label_num, | |
| float min_area=0): | |
| cdef np.ndarray[np.int32_t, ndim=2] pred | |
| pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) | |
| for label_idx in range(1, label_num): | |
| if np.sum(label == label_idx) < min_area: | |
| label[label == label_idx] = 0 | |
| cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ | |
| queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() | |
| cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ | |
| queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() | |
| cdef np.int16_t* dx = [-1, 1, 0, 0] | |
| cdef np.int16_t* dy = [0, 0, -1, 1] | |
| cdef np.int16_t tmpx, tmpy | |
| points = np.array(np.where(label > 0)).transpose((1, 0)) | |
| for point_idx in range(points.shape[0]): | |
| tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] | |
| que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) | |
| pred[tmpx, tmpy] = label[tmpx, tmpy] | |
| cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur | |
| cdef int cur_label | |
| for kernel_idx in range(kernel_num - 1, -1, -1): | |
| while not que.empty(): | |
| cur = que.front() | |
| que.pop() | |
| cur_label = pred[cur.first, cur.second] | |
| is_edge = True | |
| for j in range(4): | |
| tmpx = cur.first + dx[j] | |
| tmpy = cur.second + dy[j] | |
| if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: | |
| continue | |
| if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: | |
| continue | |
| que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) | |
| pred[tmpx, tmpy] = cur_label | |
| is_edge = False | |
| if is_edge: | |
| nxt_que.push(cur) | |
| que, nxt_que = nxt_que, que | |
| return pred | |
| def pse(kernels, min_area): | |
| kernel_num = kernels.shape[0] | |
| label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) | |
| return _pse(kernels[:-1], label, kernel_num, label_num, min_area) |