Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Point Modules | |
| Pointcept detached version | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import sys | |
| import torch.nn as nn | |
| import spconv.pytorch as spconv | |
| from collections import OrderedDict | |
| from .structure import Point | |
class PointModule(nn.Module):
    r"""Marker base class for Point-aware modules.

    ``PointSequential.forward`` tests ``isinstance(module, PointModule)`` and
    hands the complete input object to such children rather than only its
    feature tensor. Subclassing this class is therefore how a module opts in
    to receiving the whole ``Point``; the class itself adds no behaviour.
    """

    def __init__(self, *args, **kwargs):
        # Pass every constructor argument straight through to nn.Module.
        super(PointModule, self).__init__(*args, **kwargs)
class PointSequential(PointModule):
    r"""A sequential container that threads its input through child modules.

    Modules are registered in the order they are passed to the constructor.
    Alternatively, a single ``OrderedDict`` mapping names to modules can be
    passed. Keyword arguments are registered by name after the positional
    modules.

    ``forward`` dispatches on the kind of each child:

    * ``PointModule`` children receive the whole input object.
    * spconv modules run on ``input.sparse_conv_feat`` when the input is a
      ``Point``, after which ``input.feat`` is refreshed from the result.
      Otherwise they receive the input directly.
    * Any other module runs on ``input.feat`` for a ``Point`` (re-syncing
      ``sparse_conv_feat`` when present), or on ``input.features`` for a
      ``spconv.SparseConvTensor`` (skipped when it has no indices). Otherwise
      it runs on the raw input.
    """

    def __init__(self, *args, **kwargs):
        """Register child modules.

        Args:
            *args: either a single ``OrderedDict`` of ``name -> module``, or
                modules to register under the names ``"0"``, ``"1"``, ...
            **kwargs: additional modules registered under their keyword name.

        Raises:
            ValueError: if a keyword name collides with an already registered
                module name.
        """
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        # kwargs preserve insertion order on every Python that torch supports,
        # so no interpreter-version guard is needed here.
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        """Return the child module at position ``idx``.

        Negative integers count from the end. A ``slice`` returns a new
        container of the same class that holds the selected modules under
        their original names, mirroring ``torch.nn.Sequential``.

        Raises:
            IndexError: if an integer ``idx`` is out of range.
            TypeError: if ``idx`` is neither an integer nor a slice.
        """
        if isinstance(idx, slice):
            selected = list(self._modules.items())[idx]
            return self.__class__(OrderedDict(selected))
        size = len(self)
        if not (-size <= idx < size):
            raise IndexError("index {} is out of range".format(idx))
        # List indexing handles negative positions natively.
        return list(self._modules.values())[idx]

    def __len__(self):
        """Return the number of registered child modules."""
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module`` to the end of the container.

        Args:
            module: the module to register.
            name: registration name; defaults to the current module count
                as a string.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        """Apply every child module in registration order.

        Args:
            input: a ``Point``, a ``spconv.SparseConvTensor``, or any value
                accepted by the first child module.

        Returns:
            The result after the last child. A ``Point`` input is updated in
            place and the same object is returned.
        """
        for module in self._modules.values():
            if isinstance(module, PointModule):
                # Point-aware module: it consumes and returns the whole input.
                input = module(input)
            elif spconv.modules.is_spconv_module(module):
                if isinstance(input, Point):
                    input.sparse_conv_feat = module(input.sparse_conv_feat)
                    # Keep the dense feature view in sync with the sparse one.
                    input.feat = input.sparse_conv_feat.features
                else:
                    input = module(input)
            else:
                # Plain PyTorch module: it only sees the feature tensor.
                if isinstance(input, Point):
                    input.feat = module(input.feat)
                    # Point is treated as dict-like here (it exposes .keys());
                    # only re-sync the sparse tensor if one was attached.
                    if "sparse_conv_feat" in input.keys():
                        input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
                            input.feat
                        )
                elif isinstance(input, spconv.SparseConvTensor):
                    # Skip the module when the sparse tensor has no indices.
                    if input.indices.shape[0] != 0:
                        input = input.replace_feature(module(input.features))
                else:
                    input = module(input)
        return input