Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torchsparse.nn as spnn | |
| from torchsparse.point_tensor import PointTensor | |
| from lib.spvcnn_utils import * | |
# Public API of this module: only the classification network is exported;
# the building-block classes below are internal helpers.
__all__ = ["SPVCNN_CLASSIFICATION"]
class BasicConvolutionBlock(nn.Module):
    """Sparse 3D convolution followed by batch norm and an in-place ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        conv = spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride)
        # Children are registered in the same order (0: conv, 1: BN, 2: ReLU),
        # so state_dict keys are "net.0.*", "net.1.*".
        self.net = nn.Sequential(conv, spnn.BatchNorm(outc), spnn.ReLU(True))

    def forward(self, x):
        """Apply conv -> BN -> ReLU to the sparse tensor ``x``."""
        return self.net(x)
class BasicDeconvolutionBlock(nn.Module):
    """Transposed sparse 3D convolution followed by batch norm and an in-place ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        deconv = spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transpose=True)
        self.net = nn.Sequential(deconv, spnn.BatchNorm(outc), spnn.ReLU(True))

    def forward(self, x):
        """Apply transposed conv -> BN -> ReLU to the sparse tensor ``x``."""
        return self.net(x)
class ResidualBlock(nn.Module):
    """Two sparse convolutions with a residual shortcut, followed by ReLU.

    The shortcut is the identity (an empty ``nn.Sequential``) when the channel
    count and stride are unchanged; otherwise it is a 1x1x1 strided conv + BN
    so that its output matches the main branch.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size of both main-branch convolutions.
        stride: stride of the first convolution (and of the projection shortcut).
        dilation: dilation of both main-branch convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc),
        )

        needs_projection = inc != outc or stride != 1
        if needs_projection:
            self.downsample = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.downsample = nn.Sequential()

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        """Return ``relu(net(x) + downsample(x))``."""
        main = self.net(x)
        shortcut = self.downsample(x)
        return self.relu(main + shortcut)
class SPVCNN_CLASSIFICATION(nn.Module):
    """SPVCNN-style sparse encoder with a global-average-pooling classifier head.

    Keyword Args:
        input_channel: number of input feature channels (required).
        num_classes: number of output logits (required).
        cr: width multiplier applied to every channel count (default 1.0).
        pres, vres: forwarded to ``initial_voxelize`` in ``forward``
            (presumably point / voxel resolution -- see ``lib.spvcnn_utils``).
            Both must be supplied for ``forward`` to run.
    """

    def __init__(self, **kwargs):
        super().__init__()

        cr = kwargs.get('cr', 1.0)
        # Only cs[0]..cs[4] are used below; the remaining entries are unused
        # by this classifier.
        cs = [int(cr * x) for x in [32, 32, 64, 128, 256, 256, 128, 96, 96]]

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['input_channel'], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        # Four encoder stages. _make_stage builds exactly the same submodule
        # layout as writing each stage out by hand, so state_dict keys
        # (e.g. "stage1.0.net.0.*") are unchanged.
        self.stage1 = self._make_stage(cs[0], cs[1])
        self.stage2 = self._make_stage(cs[1], cs[2])
        self.stage3 = self._make_stage(cs[2], cs[3])
        self.stage4 = self._make_stage(cs[3], cs[4])

        self.avg_pool = spnn.GlobalAveragePooling()
        self.classifier = nn.Sequential(nn.Linear(cs[4], kwargs['num_classes']))

        # Projects stem-level point features (width cs[0]) to the stage-4 width
        # (cs[4]) so they can be added to features taken from the stage-4 voxels.
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            ),
        ])

        self.weight_initialization()
        # NOTE(review): not referenced by forward(); kept so the attribute
        # still exists for any external code that uses it.
        self.dropout = nn.Dropout(0.3, True)

    @staticmethod
    def _make_stage(inc, outc):
        """Build one encoder stage: a stride-2 conv block, then two residual blocks."""
        return nn.Sequential(
            BasicConvolutionBlock(inc, inc, ks=2, stride=2, dilation=1),
            ResidualBlock(inc, outc, ks=3, stride=1, dilation=1),
            ResidualBlock(outc, outc, ks=3, stride=1, dilation=1),
        )

    def weight_initialization(self):
        """Reset every ``nn.BatchNorm1d`` submodule to weight=1, bias=0.

        NOTE(review): ``spnn.BatchNorm`` layers are affected only if that class
        subclasses ``nn.BatchNorm1d`` -- confirm against the torchsparse version.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute class logits for a sparse input.

        Args:
            x: input SparseTensor; its features ``x.F`` and coordinates ``x.C``
                seed the point branch.

        Returns:
            Output of ``self.classifier`` applied to the globally average-pooled
            stage-4 features.

        Raises:
            AttributeError: if ``pres``/``vres`` were not provided, since
                ``initial_voxelize`` needs both.
        """
        if not (hasattr(self, 'pres') and hasattr(self, 'vres')):
            raise AttributeError(
                "SPVCNN_CLASSIFICATION requires both 'pres' and 'vres' keyword "
                "arguments at construction time before forward() can run"
            )

        # x: SparseTensor, z: PointTensor built from its features/coordinates.
        z = PointTensor(x.F, x.C.float())

        x0 = initial_voxelize(z, self.pres, self.vres)
        x0 = self.stem(x0)
        # Map stem voxel features back onto the points.
        z0 = voxel_to_point(x0, z, nearest=False)

        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        # Fuse stage-4 voxel features with the projected stem point features,
        # scatter back to the stage-4 voxels, then pool and classify.
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)
        y1 = point_to_voxel(x4, z1)

        pool = self.avg_pool(y1)
        return self.classifier(pool)