| """ | |
| 3D point cloud augmentation | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |

import random
import numbers
import copy
from collections.abc import Sequence, Mapping

import numpy as np
import scipy
import scipy.ndimage
import scipy.interpolate
import scipy.stats
import torch

from .registry import Registry

# Transforms below register themselves so Compose can build them from config dicts.
TRANSFORMS = Registry("transforms")


def index_operator(data_dict, index, duplicate=False):
    # Index selection operator for keys listed in "index_valid_keys".
    # Customize these keys with the "Update" transform in the config.
    if "index_valid_keys" not in data_dict:
        data_dict["index_valid_keys"] = [
            "coord",
            "color",
            "normal",
            "strength",
            "segment",
            "instance",
        ]
    if not duplicate:
        for key in data_dict["index_valid_keys"]:
            if key in data_dict:
                data_dict[key] = data_dict[key][index]
        return data_dict
    else:
        data_dict_ = dict()
        for key in data_dict.keys():
            if key in data_dict["index_valid_keys"]:
                data_dict_[key] = data_dict[key][index]
            else:
                data_dict_[key] = data_dict[key]
        return data_dict_
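

# Usage sketch (illustrative, not part of the original module): index_operator
# applies one index or boolean mask to every per-point array listed under
# "index_valid_keys", keeping the dict consistent after subsampling.
#
#   data = dict(coord=np.random.rand(100, 3), segment=np.zeros(100, dtype=int))
#   keep = np.arange(50)               # keep the first 50 points
#   data = index_operator(data, keep)  # "coord" and "segment" are now length 50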


@TRANSFORMS.register_module()
class Collect(object):
    def __init__(self, keys, offset_keys_dict=None, **kwargs):
        """
        e.g. Collect(keys=("coord",), feat_keys=("coord", "color"))
        """
        if offset_keys_dict is None:
            offset_keys_dict = dict(offset="coord")
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        self.offset_keys = offset_keys_dict
        self.kwargs = kwargs

    def __call__(self, data_dict):
        data = dict()
        for key in self.keys:
            data[key] = data_dict[key]
        for key, value in self.offset_keys.items():
            data[key] = torch.tensor([data_dict[value].shape[0]])
        for name, keys in self.kwargs.items():
            name = name.replace("_keys", "")
            assert isinstance(keys, Sequence)
            data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
        return data
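

# Usage sketch (illustrative): Collect assembles the final training sample.
# Keyword arguments ending in "_keys" are concatenated along the channel
# dimension, so feat_keys=("coord", "color") yields an (N, 6) "feat" tensor.
#
#   collect = Collect(keys=("coord",), feat_keys=("coord", "color"))
#   sample = collect(dict(coord=torch.rand(10, 3), color=torch.rand(10, 3)))
#   # sample["feat"].shape == (10, 6); sample["offset"] == tensor([10])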


@TRANSFORMS.register_module()
class Copy(object):
    def __init__(self, keys_dict=None):
        if keys_dict is None:
            keys_dict = dict(coord="origin_coord", segment="origin_segment")
        self.keys_dict = keys_dict

    def __call__(self, data_dict):
        for key, value in self.keys_dict.items():
            if isinstance(data_dict[key], np.ndarray):
                data_dict[value] = data_dict[key].copy()
            elif isinstance(data_dict[key], torch.Tensor):
                data_dict[value] = data_dict[key].clone().detach()
            else:
                data_dict[value] = copy.deepcopy(data_dict[key])
        return data_dict


@TRANSFORMS.register_module()
class Update(object):
    def __init__(self, keys_dict=None):
        if keys_dict is None:
            keys_dict = dict()
        self.keys_dict = keys_dict

    def __call__(self, data_dict):
        for key, value in self.keys_dict.items():
            data_dict[key] = value
        return data_dict


@TRANSFORMS.register_module()
class ToTensor(object):
    def __call__(self, data):
        if isinstance(data, torch.Tensor):
            return data
        elif isinstance(data, str):
            # str is also a Sequence, so this check must come before the Sequence branch
            return data
        elif isinstance(data, int):
            return torch.LongTensor([data])
        elif isinstance(data, float):
            return torch.FloatTensor([data])
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
            return torch.from_numpy(data)
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
            return torch.from_numpy(data).long()
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
            return torch.from_numpy(data).float()
        elif isinstance(data, Mapping):
            return {sub_key: self(item) for sub_key, item in data.items()}
        elif isinstance(data, Sequence):
            return [self(item) for item in data]
        else:
            raise TypeError(f"type {type(data)} cannot be converted to tensor.")
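

# Usage sketch (illustrative): ToTensor recurses through mappings and
# sequences, so a whole data_dict can be converted in one call.
#
#   data = dict(coord=np.random.rand(4, 3), segment=np.zeros(4, dtype=np.int64))
#   data = ToTensor()(data)
#   # data["coord"] is a FloatTensor, data["segment"] a LongTensor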


@TRANSFORMS.register_module()
class NormalizeColor(object):
    def __call__(self, data_dict):
        if "color" in data_dict.keys():
            data_dict["color"] = data_dict["color"] / 255
        return data_dict


@TRANSFORMS.register_module()
class NormalizeCoord(object):
    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            # modified from pointnet2
            centroid = np.mean(data_dict["coord"], axis=0)
            data_dict["coord"] -= centroid
            m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
            data_dict["coord"] = data_dict["coord"] / m
        return data_dict


@TRANSFORMS.register_module()
class PositiveShift(object):
    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            coord_min = np.min(data_dict["coord"], 0)
            data_dict["coord"] -= coord_min
        return data_dict


@TRANSFORMS.register_module()
class CenterShift(object):
    def __init__(self, apply_z=True):
        self.apply_z = apply_z

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            x_min, y_min, z_min = data_dict["coord"].min(axis=0)
            x_max, y_max, _ = data_dict["coord"].max(axis=0)
            if self.apply_z:
                shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
            else:
                shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
            data_dict["coord"] -= shift
        return data_dict


@TRANSFORMS.register_module()
class RandomShift(object):
    def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
        self.shift = shift

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
            shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
            shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
            data_dict["coord"] += [shift_x, shift_y, shift_z]
        return data_dict


@TRANSFORMS.register_module()
class PointClip(object):
    def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
        self.point_cloud_range = point_cloud_range

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            data_dict["coord"] = np.clip(
                data_dict["coord"],
                a_min=self.point_cloud_range[:3],
                a_max=self.point_cloud_range[3:],
            )
        return data_dict


@TRANSFORMS.register_module()
class RandomDropout(object):
    def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
        """
        dropout_ratio: fraction of points dropped when the transform fires
        dropout_application_ratio: probability of applying the dropout
        """
        self.dropout_ratio = dropout_ratio
        self.dropout_application_ratio = dropout_application_ratio

    def __call__(self, data_dict):
        if random.random() < self.dropout_application_ratio:
            n = len(data_dict["coord"])
            idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
            if "sampled_index" in data_dict:
                # for ScanNet data-efficient training, make sure labeled points are sampled
                idx = np.unique(np.append(idx, data_dict["sampled_index"]))
                mask = np.zeros_like(data_dict["segment"]).astype(bool)
                mask[data_dict["sampled_index"]] = True
                data_dict["sampled_index"] = np.where(mask[idx])[0]
            data_dict = index_operator(data_dict, idx)
        return data_dict


@TRANSFORMS.register_module()
class RandomRotate(object):
    def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
        self.angle = [-1, 1] if angle is None else angle
        self.axis = axis
        self.always_apply = always_apply
        self.p = p if not self.always_apply else 1
        self.center = center

    def __call__(self, data_dict):
        if random.random() > self.p:
            return data_dict
        angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
        rot_cos, rot_sin = np.cos(angle), np.sin(angle)
        if self.axis == "x":
            rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
        elif self.axis == "y":
            rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
        elif self.axis == "z":
            rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
        else:
            raise NotImplementedError
        if "coord" in data_dict.keys():
            if self.center is None:
                x_min, y_min, z_min = data_dict["coord"].min(axis=0)
                x_max, y_max, z_max = data_dict["coord"].max(axis=0)
                center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
            else:
                center = self.center
            data_dict["coord"] -= center
            data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
            data_dict["coord"] += center
        if "normal" in data_dict.keys():
            data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
        return data_dict
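

# Usage sketch (illustrative): rotate around the z axis by a random angle in
# [-pi, pi]; p=1.0 forces the transform to fire for the example.
#
#   rotate = RandomRotate(angle=[-1, 1], axis="z", p=1.0)
#   data = rotate(dict(coord=np.random.rand(100, 3)))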


@TRANSFORMS.register_module()
class RandomRotateTargetAngle(object):
    def __init__(
        self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
    ):
        self.angle = angle
        self.axis = axis
        self.always_apply = always_apply
        self.p = p if not self.always_apply else 1
        self.center = center

    def __call__(self, data_dict):
        if random.random() > self.p:
            return data_dict
        angle = np.random.choice(self.angle) * np.pi
        rot_cos, rot_sin = np.cos(angle), np.sin(angle)
        if self.axis == "x":
            rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
        elif self.axis == "y":
            rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
        elif self.axis == "z":
            rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
        else:
            raise NotImplementedError
        if "coord" in data_dict.keys():
            if self.center is None:
                x_min, y_min, z_min = data_dict["coord"].min(axis=0)
                x_max, y_max, z_max = data_dict["coord"].max(axis=0)
                center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
            else:
                center = self.center
            data_dict["coord"] -= center
            data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
            data_dict["coord"] += center
        if "normal" in data_dict.keys():
            data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
        return data_dict


@TRANSFORMS.register_module()
class RandomScale(object):
    def __init__(self, scale=None, anisotropic=False):
        self.scale = scale if scale is not None else [0.95, 1.05]
        self.anisotropic = anisotropic

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            scale = np.random.uniform(
                self.scale[0], self.scale[1], 3 if self.anisotropic else 1
            )
            data_dict["coord"] *= scale
        return data_dict


@TRANSFORMS.register_module()
class RandomFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data_dict):
        if np.random.rand() < self.p:
            if "coord" in data_dict.keys():
                data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
            if "normal" in data_dict.keys():
                data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
        if np.random.rand() < self.p:
            if "coord" in data_dict.keys():
                data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
            if "normal" in data_dict.keys():
                data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
        return data_dict


@TRANSFORMS.register_module()
class RandomJitter(object):
    def __init__(self, sigma=0.01, clip=0.05):
        assert clip > 0
        self.sigma = sigma
        self.clip = clip

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            jitter = np.clip(
                self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
                -self.clip,
                self.clip,
            )
            data_dict["coord"] += jitter
        return data_dict


@TRANSFORMS.register_module()
class ClipGaussianJitter(object):
    def __init__(self, scalar=0.02, store_jitter=False):
        self.scalar = scalar
        self.mean = np.zeros(3)  # fixed: np.mean(3) is not a valid mean vector
        self.cov = np.identity(3)
        self.quantile = 1.96
        self.store_jitter = store_jitter

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            jitter = np.random.multivariate_normal(
                self.mean, self.cov, data_dict["coord"].shape[0]
            )
            jitter = self.scalar * np.clip(jitter / self.quantile, -1, 1)
            data_dict["coord"] += jitter
            if self.store_jitter:
                data_dict["jitter"] = jitter
        return data_dict
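

# Usage sketch (illustrative): unlike RandomJitter's hard clip, this variant
# draws isotropic Gaussian noise, divides by 1.96 (the two-sided 95% z-value)
# and clips to [-1, 1], so the jitter magnitude stays within +/- scalar.
#
#   data = ClipGaussianJitter(scalar=0.02)(dict(coord=np.random.rand(100, 3)))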


@TRANSFORMS.register_module()
class ChromaticAutoContrast(object):
    def __init__(self, p=0.2, blend_factor=None):
        self.p = p
        self.blend_factor = blend_factor

    def __call__(self, data_dict):
        if "color" in data_dict.keys() and np.random.rand() < self.p:
            lo = np.min(data_dict["color"], 0, keepdims=True)
            hi = np.max(data_dict["color"], 0, keepdims=True)
            # guard against division by zero on constant color channels
            scale = 255 / np.clip(hi - lo, 1, None)
            contrast_feat = (data_dict["color"][:, :3] - lo) * scale
            blend_factor = (
                np.random.rand() if self.blend_factor is None else self.blend_factor
            )
            data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
                :, :3
            ] + blend_factor * contrast_feat
        return data_dict


@TRANSFORMS.register_module()
class ChromaticTranslation(object):
    def __init__(self, p=0.95, ratio=0.05):
        self.p = p
        self.ratio = ratio

    def __call__(self, data_dict):
        if "color" in data_dict.keys() and np.random.rand() < self.p:
            tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
            data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
        return data_dict


@TRANSFORMS.register_module()
class ChromaticJitter(object):
    def __init__(self, p=0.95, std=0.005):
        self.p = p
        self.std = std

    def __call__(self, data_dict):
        if "color" in data_dict.keys() and np.random.rand() < self.p:
            noise = np.random.randn(data_dict["color"].shape[0], 3)
            noise *= self.std * 255
            data_dict["color"][:, :3] = np.clip(
                noise + data_dict["color"][:, :3], 0, 255
            )
        return data_dict


@TRANSFORMS.register_module()
class RandomColorGrayScale(object):
    def __init__(self, p):
        self.p = p

    @staticmethod
    def rgb_to_grayscale(color, num_output_channels=1):
        if color.shape[-1] < 3:
            raise TypeError(
                "Input color should have at least 3 channels, but found {}".format(
                    color.shape[-1]
                )
            )
        if num_output_channels not in (1, 3):
            raise ValueError("num_output_channels should be either 1 or 3")
        r, g, b = color[..., 0], color[..., 1], color[..., 2]
        gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
        gray = np.expand_dims(gray, axis=-1)
        if num_output_channels == 3:
            gray = np.broadcast_to(gray, color.shape)
        return gray

    def __call__(self, data_dict):
        if np.random.rand() < self.p:
            data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
        return data_dict


@TRANSFORMS.register_module()
class RandomColorJitter(object):
    """
    Random color jitter for 3D point clouds (adapted from torchvision).
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(
            hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
        )
        self.p = p

    @staticmethod
    def _check_input(
        value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
    ):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(
                    "If {} is a single number, it must be non negative.".format(name)
                )
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError(
                "{} should be a single number or a list/tuple with length 2.".format(
                    name
                )
            )
        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def blend(color1, color2, ratio):
        ratio = float(ratio)
        bound = 255.0
        return (
            (ratio * color1 + (1.0 - ratio) * color2)
            .clip(0, bound)
            .astype(color1.dtype)
        )

    @staticmethod
    def rgb2hsv(rgb):
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        maxc = np.max(rgb, axis=-1)
        minc = np.min(rgb, axis=-1)
        eqc = maxc == minc
        cr = maxc - minc
        s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
        cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
        rc = (maxc - r) / cr_divisor
        gc = (maxc - g) / cr_divisor
        bc = (maxc - b) / cr_divisor
        hr = (maxc == r) * (bc - gc)
        hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
        hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
        h = hr + hg + hb
        h = (h / 6.0 + 1.0) % 1.0
        return np.stack((h, s, maxc), axis=-1)

    @staticmethod
    def hsv2rgb(hsv):
        h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
        i = np.floor(h * 6.0)
        f = (h * 6.0) - i
        i = i.astype(np.int32)
        p = np.clip((v * (1.0 - s)), 0.0, 1.0)
        q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
        t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
        i = i % 6
        mask = np.expand_dims(i, axis=-1) == np.arange(6)
        a1 = np.stack((v, q, p, p, t, v), axis=-1)
        a2 = np.stack((t, v, v, q, p, p), axis=-1)
        a3 = np.stack((p, p, t, v, v, q), axis=-1)
        a4 = np.stack((a1, a2, a3), axis=-1)
        return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)

    def adjust_brightness(self, color, brightness_factor):
        if brightness_factor < 0:
            raise ValueError(
                "brightness_factor ({}) is not non-negative.".format(brightness_factor)
            )
        return self.blend(color, np.zeros_like(color), brightness_factor)

    def adjust_contrast(self, color, contrast_factor):
        if contrast_factor < 0:
            raise ValueError(
                "contrast_factor ({}) is not non-negative.".format(contrast_factor)
            )
        mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
        return self.blend(color, mean, contrast_factor)

    def adjust_saturation(self, color, saturation_factor):
        if saturation_factor < 0:
            raise ValueError(
                "saturation_factor ({}) is not non-negative.".format(saturation_factor)
            )
        gray = RandomColorGrayScale.rgb_to_grayscale(color)
        return self.blend(color, gray, saturation_factor)

    def adjust_hue(self, color, hue_factor):
        if not (-0.5 <= hue_factor <= 0.5):
            raise ValueError(
                "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
            )
        orig_dtype = color.dtype
        hsv = self.rgb2hsv(color / 255.0)
        h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
        h = (h + hue_factor) % 1.0
        hsv = np.stack((h, s, v), axis=-1)
        color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
        return color_hue_adj

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        fn_idx = torch.randperm(4)
        b = (
            None
            if brightness is None
            else np.random.uniform(brightness[0], brightness[1])
        )
        c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
        s = (
            None
            if saturation is None
            else np.random.uniform(saturation[0], saturation[1])
        )
        h = None if hue is None else np.random.uniform(hue[0], hue[1])
        return fn_idx, b, c, s, h

    def __call__(self, data_dict):
        (
            fn_idx,
            brightness_factor,
            contrast_factor,
            saturation_factor,
            hue_factor,
        ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
        for fn_id in fn_idx:
            if (
                fn_id == 0
                and brightness_factor is not None
                and np.random.rand() < self.p
            ):
                data_dict["color"] = self.adjust_brightness(
                    data_dict["color"], brightness_factor
                )
            elif (
                fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
            ):
                data_dict["color"] = self.adjust_contrast(
                    data_dict["color"], contrast_factor
                )
            elif (
                fn_id == 2
                and saturation_factor is not None
                and np.random.rand() < self.p
            ):
                data_dict["color"] = self.adjust_saturation(
                    data_dict["color"], saturation_factor
                )
            elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
                data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
        return data_dict


@TRANSFORMS.register_module()
class HueSaturationTranslation(object):
    @staticmethod
    def rgb_to_hsv(rgb):
        # Translated from the source of colorsys.rgb_to_hsv.
        # r, g, b should be numpy arrays with values between 0 and 255;
        # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
        rgb = rgb.astype("float")
        hsv = np.zeros_like(rgb)
        # in case an RGBA array was passed, just copy the A channel
        hsv[..., 3:] = rgb[..., 3:]
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        maxc = np.max(rgb[..., :3], axis=-1)
        minc = np.min(rgb[..., :3], axis=-1)
        hsv[..., 2] = maxc
        mask = maxc != minc
        hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
        rc = np.zeros_like(r)
        gc = np.zeros_like(g)
        bc = np.zeros_like(b)
        rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
        gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
        bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
        hsv[..., 0] = np.select(
            [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
        )
        hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
        return hsv

    @staticmethod
    def hsv_to_rgb(hsv):
        # Translated from the source of colorsys.hsv_to_rgb.
        # h, s should be numpy arrays with values between 0.0 and 1.0;
        # v should be a numpy array with values between 0.0 and 255.0;
        # hsv_to_rgb returns an array of uints between 0 and 255.
        rgb = np.empty_like(hsv)
        rgb[..., 3:] = hsv[..., 3:]
        h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
        i = (h * 6.0).astype("uint8")
        f = (h * 6.0) - i
        p = v * (1.0 - s)
        q = v * (1.0 - s * f)
        t = v * (1.0 - s * (1.0 - f))
        i = i % 6
        conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
        rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
        rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
        rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
        return rgb.astype("uint8")

    def __init__(self, hue_max=0.5, saturation_max=0.2):
        self.hue_max = hue_max
        self.saturation_max = saturation_max

    def __call__(self, data_dict):
        if "color" in data_dict.keys():
            # Assume color[:, :3] is rgb
            hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
            hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
            sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
            hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
            hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
            data_dict["color"][:, :3] = np.clip(
                HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
            )
        return data_dict
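

# Usage sketch (illustrative): colors are expected as uint8-style values in
# [0, 255]; hue is shifted by at most hue_max (in [0, 1] hue space) and
# saturation is rescaled by at most +/- saturation_max.
#
#   aug = HueSaturationTranslation(hue_max=0.5, saturation_max=0.2)
#   data = aug(dict(color=np.random.randint(0, 256, (100, 3))))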


@TRANSFORMS.register_module()
class RandomColorDrop(object):
    def __init__(self, p=0.2, color_augment=0.0):
        self.p = p
        self.color_augment = color_augment

    def __call__(self, data_dict):
        if "color" in data_dict.keys() and np.random.rand() < self.p:
            data_dict["color"] *= self.color_augment
        return data_dict

    def __repr__(self):
        return "RandomColorDrop(color_augment: {}, p: {})".format(
            self.color_augment, self.p
        )


@TRANSFORMS.register_module()
class ElasticDistortion(object):
    def __init__(self, distortion_params=None):
        self.distortion_params = (
            [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
        )

    @staticmethod
    def elastic_distortion(coords, granularity, magnitude):
        """
        Apply elastic distortion on sparse coordinate space.
        coords: numpy array of (number of points, at least 3 spatial dims)
        granularity: size of the noise grid (in same scale [m/cm] as the voxel grid)
        magnitude: noise multiplier
        """
        blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
        blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
        blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
        coords_min = coords.min(0)

        # Create Gaussian noise tensor of the size given by granularity.
        noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
        noise = np.random.randn(*noise_dim, 3).astype(np.float32)

        # Smoothing (the deprecated scipy.ndimage.filters namespace was removed
        # in recent SciPy releases, so call scipy.ndimage directly).
        for _ in range(2):
            noise = scipy.ndimage.convolve(noise, blurx, mode="constant", cval=0)
            noise = scipy.ndimage.convolve(noise, blury, mode="constant", cval=0)
            noise = scipy.ndimage.convolve(noise, blurz, mode="constant", cval=0)

        # Trilinearly interpolate the noise grid for each spatial dimension.
        ax = [
            np.linspace(d_min, d_max, d)
            for d_min, d_max, d in zip(
                coords_min - granularity,
                coords_min + granularity * (noise_dim - 2),
                noise_dim,
            )
        ]
        interp = scipy.interpolate.RegularGridInterpolator(
            ax, noise, bounds_error=False, fill_value=0
        )
        coords += interp(coords) * magnitude
        return coords

    def __call__(self, data_dict):
        if "coord" in data_dict.keys() and self.distortion_params is not None:
            if random.random() < 0.95:
                for granularity, magnitude in self.distortion_params:
                    data_dict["coord"] = self.elastic_distortion(
                        data_dict["coord"], granularity, magnitude
                    )
        return data_dict


@TRANSFORMS.register_module()
class GridSample(object):
    def __init__(
        self,
        grid_size=0.05,
        hash_type="fnv",
        mode="train",
        return_inverse=False,
        return_grid_coord=False,
        return_min_coord=False,
        return_displacement=False,
        project_displacement=False,
    ):
        self.grid_size = grid_size
        self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
        assert mode in ["train", "test"]
        self.mode = mode
        self.return_inverse = return_inverse
        self.return_grid_coord = return_grid_coord
        self.return_min_coord = return_min_coord
        self.return_displacement = return_displacement
        self.project_displacement = project_displacement

    def __call__(self, data_dict):
        assert "coord" in data_dict.keys()
        scaled_coord = data_dict["coord"] / np.array(self.grid_size)
        grid_coord = np.floor(scaled_coord).astype(int)
        min_coord = grid_coord.min(0)
        grid_coord -= min_coord
        scaled_coord -= min_coord
        min_coord = min_coord * np.array(self.grid_size)
        key = self.hash(grid_coord)
        idx_sort = np.argsort(key)
        key_sort = key[idx_sort]
        _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
        if self.mode == "train":  # train mode
            idx_select = (
                np.cumsum(np.insert(count, 0, 0)[0:-1])
                + np.random.randint(0, count.max(), count.size) % count
            )
            idx_unique = idx_sort[idx_select]
            if "sampled_index" in data_dict:
                # for ScanNet data-efficient training, make sure labeled points are sampled
                idx_unique = np.unique(
                    np.append(idx_unique, data_dict["sampled_index"])
                )
                mask = np.zeros_like(data_dict["segment"]).astype(bool)
                mask[data_dict["sampled_index"]] = True
                data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
            data_dict = index_operator(data_dict, idx_unique)
            if self.return_inverse:
                data_dict["inverse"] = np.zeros_like(inverse)
                data_dict["inverse"][idx_sort] = inverse
            if self.return_grid_coord:
                data_dict["grid_coord"] = grid_coord[idx_unique]
                data_dict["index_valid_keys"].append("grid_coord")
            if self.return_min_coord:
                data_dict["min_coord"] = min_coord.reshape([1, 3])
            if self.return_displacement:
                displacement = (
                    scaled_coord - grid_coord - 0.5
                )  # [0, 1] -> [-0.5, 0.5] displacement to the voxel center
                if self.project_displacement:
                    displacement = np.sum(
                        displacement * data_dict["normal"], axis=-1, keepdims=True
                    )
                data_dict["displacement"] = displacement[idx_unique]
                data_dict["index_valid_keys"].append("displacement")
            return data_dict
        elif self.mode == "test":  # test mode
            data_part_list = []
            for i in range(count.max()):
                idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
                idx_part = idx_sort[idx_select]
                data_part = index_operator(data_dict, idx_part, duplicate=True)
                data_part["index"] = idx_part
                if self.return_inverse:
                    data_part["inverse"] = np.zeros_like(inverse)
                    data_part["inverse"][idx_sort] = inverse
                if self.return_grid_coord:
                    data_part["grid_coord"] = grid_coord[idx_part]
                    # guard against appending the key once per part
                    if "grid_coord" not in data_part["index_valid_keys"]:
                        data_part["index_valid_keys"].append("grid_coord")
                if self.return_min_coord:
                    data_part["min_coord"] = min_coord.reshape([1, 3])
                if self.return_displacement:
                    displacement = (
                        scaled_coord - grid_coord - 0.5
                    )  # [0, 1] -> [-0.5, 0.5] displacement to the voxel center
                    if self.project_displacement:
                        displacement = np.sum(
                            displacement * data_dict["normal"], axis=-1, keepdims=True
                        )
                    # fixed: displacement belongs to the part, not the source dict
                    data_part["displacement"] = displacement[idx_part]
                    if "displacement" not in data_part["index_valid_keys"]:
                        data_part["index_valid_keys"].append("displacement")
                data_part_list.append(data_part)
            return data_part_list
        else:
            raise NotImplementedError

    @staticmethod
    def ravel_hash_vec(arr):
        """
        Ravel the coordinates after subtracting the min coordinates.
        """
        assert arr.ndim == 2
        arr = arr.copy()
        arr -= arr.min(0)
        arr = arr.astype(np.uint64, copy=False)
        arr_max = arr.max(0).astype(np.uint64) + 1
        keys = np.zeros(arr.shape[0], dtype=np.uint64)
        # Fortran-style indexing
        for j in range(arr.shape[1] - 1):
            keys += arr[:, j]
            keys *= arr_max[j + 1]
        keys += arr[:, -1]
        return keys

    @staticmethod
    def fnv_hash_vec(arr):
        """
        FNV64-1A hash for 2D coordinate arrays.
        """
        assert arr.ndim == 2
        # floor first for negative coordinates
        arr = arr.copy()
        arr = arr.astype(np.uint64, copy=False)
        hashed_arr = np.uint64(14695981039346656037) * np.ones(
            arr.shape[0], dtype=np.uint64
        )
        for j in range(arr.shape[1]):
            hashed_arr *= np.uint64(1099511628211)
            hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
        return hashed_arr
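

# Usage sketch (illustrative): in "train" mode GridSample keeps one random
# point per occupied voxel; in "test" mode it returns a list of parts that
# together cover every point, so predictions can be mapped back via "index".
#
#   sampler = GridSample(grid_size=0.05, mode="train", return_grid_coord=True)
#   data = sampler(dict(coord=np.random.rand(1000, 3)))
#   # data["coord"] and data["grid_coord"] now hold one point per voxel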


@TRANSFORMS.register_module()
class SphereCrop(object):
    def __init__(self, point_max=80000, sample_rate=None, mode="random"):
        self.point_max = point_max
        self.sample_rate = sample_rate
        assert mode in ["random", "center", "all"]
        self.mode = mode

    def __call__(self, data_dict):
        point_max = (
            int(self.sample_rate * data_dict["coord"].shape[0])
            if self.sample_rate is not None
            else self.point_max
        )
        assert "coord" in data_dict.keys()
        if data_dict["coord"].shape[0] > point_max:
            if self.mode == "random":
                center = data_dict["coord"][
                    np.random.randint(data_dict["coord"].shape[0])
                ]
            elif self.mode == "center":
                center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
            else:
                raise NotImplementedError
            idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
                :point_max
            ]
            data_dict = index_operator(data_dict, idx_crop)
        return data_dict


@TRANSFORMS.register_module()
class ShufflePoint(object):
    def __call__(self, data_dict):
        assert "coord" in data_dict.keys()
        shuffle_index = np.arange(data_dict["coord"].shape[0])
        np.random.shuffle(shuffle_index)
        data_dict = index_operator(data_dict, shuffle_index)
        return data_dict


@TRANSFORMS.register_module()
class CropBoundary(object):
    def __call__(self, data_dict):
        assert "segment" in data_dict
        segment = data_dict["segment"].flatten()
        mask = (segment != 0) * (segment != 1)
        data_dict = index_operator(data_dict, mask)
        return data_dict


@TRANSFORMS.register_module()
class ContrastiveViewsGenerator(object):
    def __init__(
        self,
        view_keys=("coord", "color", "normal", "origin_coord"),
        view_trans_cfg=None,
    ):
        self.view_keys = view_keys
        self.view_trans = Compose(view_trans_cfg)

    def __call__(self, data_dict):
        view1_dict = dict()
        view2_dict = dict()
        for key in self.view_keys:
            view1_dict[key] = data_dict[key].copy()
            view2_dict[key] = data_dict[key].copy()
        view1_dict = self.view_trans(view1_dict)
        view2_dict = self.view_trans(view2_dict)
        for key, value in view1_dict.items():
            data_dict["view1_" + key] = value
        for key, value in view2_dict.items():
            data_dict["view2_" + key] = value
        return data_dict


@TRANSFORMS.register_module()
class MultiViewGenerator(object):
    def __init__(
        self,
        global_view_num=2,
        global_view_scale=(0.4, 1.0),
        local_view_num=4,
        local_view_scale=(0.1, 0.4),
        global_shared_transform=None,
        global_transform=None,
        local_transform=None,
        max_size=65536,
        center_height_scale=(0, 1),
        shared_global_view=False,
        view_keys=("coord", "origin_coord", "color", "normal"),
    ):
        self.global_view_num = global_view_num
        self.global_view_scale = global_view_scale
        self.local_view_num = local_view_num
        self.local_view_scale = local_view_scale
        self.global_shared_transform = Compose(global_shared_transform)
        self.global_transform = Compose(global_transform)
        self.local_transform = Compose(local_transform)
        self.max_size = max_size
        self.center_height_scale = center_height_scale
        self.shared_global_view = shared_global_view
        self.view_keys = view_keys
        assert "coord" in view_keys

    def get_view(self, point, center, scale):
        coord = point["coord"]
        max_size = min(self.max_size, coord.shape[0])
        size = int(np.random.uniform(*scale) * max_size)
        index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]
        view = dict(index=index)
        for key in point.keys():
            if key in self.view_keys:
                view[key] = point[key][index]
        if "index_valid_keys" in point.keys():
            # inherit index_valid_keys from point
            view["index_valid_keys"] = point["index_valid_keys"]
        return view

    def __call__(self, data_dict):
        coord = data_dict["coord"]
        point = self.global_shared_transform(copy.deepcopy(data_dict))
        z_min = coord[:, 2].min()
        z_max = coord[:, 2].max()
        z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0]
        z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1]
        center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_)
        # get major global view
        major_center = coord[np.random.choice(np.where(center_mask)[0])]
        major_view = self.get_view(point, major_center, self.global_view_scale)
        major_coord = major_view["coord"]
        # get global views: restrict the centers of the remaining global views
        # to points within the major global view
        if not self.shared_global_view:
            global_views = [
                self.get_view(
                    point=point,
                    center=major_coord[np.random.randint(major_coord.shape[0])],
                    scale=self.global_view_scale,
                )
                for _ in range(self.global_view_num - 1)
            ]
        else:
            global_views = [
                {key: value.copy() for key, value in major_view.items()}
                for _ in range(self.global_view_num - 1)
            ]
        global_views = [major_view] + global_views
        # get local views: restrict the center of each local view to points
        # within the major global view
        cover_mask = np.zeros_like(major_view["index"], dtype=bool)
        local_views = []
        for i in range(self.local_view_num):
            if np.sum(~cover_mask) == 0:
                # reset the cover mask once all points have been sampled
                cover_mask[:] = False
            local_view = self.get_view(
                point=data_dict,
                center=major_coord[np.random.choice(np.where(~cover_mask)[0])],
                scale=self.local_view_scale,
            )
            local_views.append(local_view)
            cover_mask[np.isin(major_view["index"], local_view["index"])] = True
        # augment and concatenate
        view_dict = {}
        for global_view in global_views:
            global_view.pop("index")
            global_view = self.global_transform(global_view)
            for key in self.view_keys:
                if f"global_{key}" in view_dict.keys():
                    view_dict[f"global_{key}"].append(global_view[key])
                else:
                    view_dict[f"global_{key}"] = [global_view[key]]
        view_dict["global_offset"] = np.cumsum(
            [data.shape[0] for data in view_dict["global_coord"]]
        )
        for local_view in local_views:
            local_view.pop("index")
            local_view = self.local_transform(local_view)
            for key in self.view_keys:
                if f"local_{key}" in view_dict.keys():
                    view_dict[f"local_{key}"].append(local_view[key])
                else:
                    view_dict[f"local_{key}"] = [local_view[key]]
        view_dict["local_offset"] = np.cumsum(
            [data.shape[0] for data in view_dict["local_coord"]]
        )
        for key in view_dict.keys():
            if "offset" not in key:
                view_dict[key] = np.concatenate(view_dict[key], axis=0)
        data_dict.update(view_dict)
        return data_dict


@TRANSFORMS.register_module()
class InstanceParser(object):
    def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
        self.segment_ignore_index = segment_ignore_index
        self.instance_ignore_index = instance_ignore_index

    def __call__(self, data_dict):
        coord = data_dict["coord"]
        segment = data_dict["segment"]
        instance = data_dict["instance"]
        mask = ~np.in1d(segment, self.segment_ignore_index)
        # map ignored instances to the ignore index
        instance[~mask] = self.instance_ignore_index
        # reindex the remaining instances
        unique, inverse = np.unique(instance[mask], return_inverse=True)
        instance_num = len(unique)
        instance[mask] = inverse
        # init instance information
        centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
        bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
        vacancy = [
            index for index in self.segment_ignore_index if index >= 0
        ]  # vacated class indices
        for instance_id in range(instance_num):
            mask_ = instance == instance_id
            coord_ = coord[mask_]
            bbox_min = coord_.min(0)
            bbox_max = coord_.max(0)
            bbox_centroid = coord_.mean(0)
            bbox_center = (bbox_max + bbox_min) / 2
            bbox_size = bbox_max - bbox_min
            bbox_theta = np.zeros(1, dtype=coord_.dtype)
            bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
            # shift class indices to fill the gaps left by segment_ignore_index
            bbox_class -= np.greater(bbox_class, vacancy).sum()
            centroid[mask_] = bbox_centroid
            bbox[instance_id] = np.concatenate(
                [bbox_center, bbox_size, bbox_theta, bbox_class]
            )  # 3 + 3 + 1 + 1 = 8
        data_dict["instance"] = instance
        data_dict["instance_centroid"] = centroid
        data_dict["bbox"] = bbox
        return data_dict


@TRANSFORMS.register_module()
class Compose(object):
    def __init__(self, cfg=None):
        self.cfg = cfg if cfg is not None else []
        self.transforms = []
        for t_cfg in self.cfg:
            self.transforms.append(TRANSFORMS.build(t_cfg))

    def __call__(self, data_dict):
        for t in self.transforms:
            data_dict = t(data_dict)
        return data_dict


def default():
    config = [
        dict(type="CenterShift", apply_z=True),
        dict(
            type="GridSample",
            grid_size=0.005,  # alternatives used during development: 0.02, 0.01, 0.0025
            hash_type="fnv",
            mode="train",
            return_grid_coord=True,
            return_inverse=True,
        ),
        dict(type="NormalizeColor"),
        dict(type="ToTensor"),
        dict(
            type="Collect",
            keys=("coord", "grid_coord", "color", "inverse"),
            feat_keys=("coord", "color", "normal"),
        ),
    ]
    return Compose(config)
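

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module): run
    # the default pipeline on a random toy cloud. The keys mirror what the
    # Collect step above expects.
    pipeline = default()
    toy = dict(
        coord=np.random.rand(1000, 3).astype(np.float32),
        color=np.random.randint(0, 256, (1000, 3)).astype(np.float32),
        normal=np.random.randn(1000, 3).astype(np.float32),
    )
    sample = pipeline(toy)
    print({k: tuple(v.shape) for k, v in sample.items() if hasattr(v, "shape")})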