""" Point Modules Pointcept detached version Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) Please cite our work if the code is helpful to you. """ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import torch.nn as nn import spconv.pytorch as spconv from collections import OrderedDict from .structure import Point class PointModule(nn.Module): r"""PointModule placeholder, all module subclass from this will take Point in PointSequential. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) class PointSequential(PointModule): r"""A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in. """ def __init__(self, *args, **kwargs): super().__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module) for name, module in kwargs.items(): if sys.version_info < (3, 6): raise ValueError("kwargs only supported in py36+") if name in self._modules: raise ValueError("name exists.") self.add_module(name, module) def __getitem__(self, idx): if not (-len(self) <= idx < len(self)): raise IndexError("index {} is out of range".format(idx)) if idx < 0: idx += len(self) it = iter(self._modules.values()) for i in range(idx): next(it) return next(it) def __len__(self): return len(self._modules) def add(self, module, name=None): if name is None: name = str(len(self._modules)) if name in self._modules: raise KeyError("name exists") self.add_module(name, module) def forward(self, input): for k, module in self._modules.items(): # Point module if isinstance(module, PointModule): input = module(input) # Spconv module elif spconv.modules.is_spconv_module(module): if isinstance(input, Point): input.sparse_conv_feat = module(input.sparse_conv_feat) input.feat = input.sparse_conv_feat.features else: input = module(input) # PyTorch module else: if isinstance(input, Point): input.feat = module(input.feat) if "sparse_conv_feat" in input.keys(): input.sparse_conv_feat = input.sparse_conv_feat.replace_feature( input.feat ) elif isinstance(input, spconv.SparseConvTensor): if input.indices.shape[0] != 0: input = input.replace_feature(module(input.features)) else: input = module(input) return input