Spaces:
Runtime error
Runtime error
File size: 5,115 Bytes
7b75adb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# @lint-ignore-every LICENSELINT
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
from typing import Optional, Union
class KeyLUT:
    """Pre-computed lookup tables for bit-interleaved (Morton / Z-order) keys.

    The encode tables map every 8-bit coordinate value to its interleaved
    contribution for the x, y and z axis respectively. The decode tables map
    every 9-bit key chunk (3 bits per axis) back to the three coordinates.
    Tables are built once on the CPU. They are copied to any other device the
    first time that device is requested, and the copy is cached.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        values8 = torch.arange(256, dtype=torch.int64)
        values9 = torch.arange(512, dtype=torch.int64)
        blank = torch.zeros(256, dtype=torch.int64)
        encode_cpu = (
            self.xyz2key(values8, blank, blank, 8),
            self.xyz2key(blank, values8, blank, 8),
            self.xyz2key(blank, blank, values8, 8),
        )
        self._encode = {cpu: encode_cpu}
        self._decode = {cpu: self.key2xyz(values9, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) encode tables living on ``device``."""
        return self._lut_on(self._encode, device)

    def decode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) decode tables living on ``device``."""
        return self._lut_on(self._decode, device)

    @staticmethod
    def _lut_on(cache, device):
        """Fetch ``cache[device]``, filling it from the CPU tables if absent."""
        try:
            return cache[device]
        except KeyError:
            source = cache[torch.device("cpu")]
            cache[device] = tables = tuple(t.to(device) for t in source)
            return tables

    def xyz2key(self, x, y, z, depth):
        """Interleave the low ``depth`` bits of x, y and z into one key.

        Bit ``i`` of x lands at bit ``3i+2``, of y at ``3i+1`` and of z at
        ``3i`` of the result.
        """
        key = torch.zeros_like(x)
        lanes = ((x, 2), (y, 1), (z, 0))
        for bit in range(depth):
            sel = 1 << bit
            for coord, offset in lanes:
                key = key | ((coord & sel) << (2 * bit + offset))
        return key

    def key2xyz(self, key, depth):
        """Invert :meth:`xyz2key`, splitting ``depth`` bit-triples into x, y, z."""
        coords = [torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)]
        for bit in range(depth):
            for axis, offset in enumerate((2, 1, 0)):
                src = 1 << (3 * bit + offset)
                coords[axis] = coords[axis] | ((key & src) >> (2 * bit + offset))
        x, y, z = coords
        return x, y, z
# Shared lookup tables used by the module-level ``xyz2key`` / ``key2xyz``.
# They are built once at import time on the CPU; copies for other devices are
# created on first request and cached inside the instance.
_key_lut = KeyLUT()
def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    """Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates into shuffled keys
    using pre-computed lookup tables. This is much faster than interleaving the
    bits one at a time in a for-loop.

    Bit ``i`` of x, y and z is placed at bit ``3i+2``, ``3i+1`` and ``3i`` of
    the key respectively. The batch index, if given, is stored from bit 48 up.

    Args:
        x (torch.Tensor): The x coordinate.
        y (torch.Tensor): The y coordinate.
        z (torch.Tensor): The z coordinate.
        b (torch.Tensor or int): The batch index of the coordinates, and should be
            smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
            :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
        depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
        torch.Tensor: The int64 shuffled keys, on the same device as :attr:`x`.
    """
    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # Low 8 bits of each coordinate (fewer when depth <= 8) via the 256-entry tables.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # Bits 8..depth-1 reuse the same tables; their 24 interleaved bits sit
        # directly above the 24 bits produced by the low byte.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # ``b`` may be a plain Python int (see docstring), which has no
        # ``.long()``. Normalize ints and tensors alike to int64 on the key's device.
        b = torch.as_tensor(b, dtype=torch.long, device=key.device)
        key = b << 48 | key
    return key
def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
        key (torch.Tensor): The shuffled key.
        depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
        tuple: ``(x, y, z, b)`` int64 tensors shaped like :attr:`key`, where
        ``b`` is the batch index taken from bits 48 and up.
    """
    # Cast first, mirroring ``xyz2key``: the 48-bit shift and mask below are
    # only meaningful for 64-bit integers (no-op for int64 input).
    key = key.long()
    DX, DY, DZ = _key_lut.decode_lut(key.device)
    x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
    b = key >> 48
    key = key & ((1 << 48) - 1)

    # Each 9-bit chunk holds 3 bits per axis, so ceil(depth / 3) chunks are needed.
    n = (depth + 2) // 3
    for i in range(n):
        k = (key >> (i * 9)) & 511
        x = x | (DX[k] << (i * 3))
        y = y | (DY[k] << (i * 3))
        z = z | (DZ[k] << (i * 3))
    return x, y, z, b
|