""" Point Transformer - V3 Mode2 - Sonata Pointcept detached version Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) Please cite our work if the code is helpful to you. """ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. 
and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json
import os

from packaging import version
from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
from addict import Dict
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
import spconv.pytorch as spconv
import torch_scatter
from timm.layers import DropPath

try:
    import flash_attn
except ImportError:
    # Flash attention is optional; SerializedAttention asserts on it only when
    # it is constructed with enable_flash=True.
    flash_attn = None

from .structure import Point
from .module import PointSequential, PointModule
from .utils import offset2bincount

# Checkpoint names that `load` resolves through the HuggingFace hub.
MODELS = [
    "sonata",
    "sonata_small",
    "sonata_linear_prob_head_sc",
]


class LayerScale(nn.Module):
    """Scale the input by a learnable per-channel factor ``gamma``.

    Args:
        dim: Length of the ``gamma`` vector.
        init_values: Initial value of every entry of ``gamma``.
        inplace: When True, multiply the input tensor in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * gamma`` (mutating ``x`` when ``inplace`` is set)."""
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class RPE(torch.nn.Module):
    """Learnable relative position encoding looked up from a bias table.

    The table holds ``3 * rpe_num`` rows (one band of ``rpe_num`` rows per
    coordinate axis) and ``num_heads`` columns.

    Args:
        patch_size: Patch size used to derive the clamp bound ``pos_bnd``.
        num_heads: Number of attention heads (columns of the table).
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Map relative grid coordinates to a per-head attention bias.

        Args:
            coord: Integer relative coordinates whose last dimension has
                size 3 (as produced by ``SerializedAttention.get_rel_pos``).

        Returns:
            The bias tensor, permuted from (N, K, K, H) to (N, H, K, K).
        """
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        # Sum the three per-axis contributions.
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out


class SerializedAttention(PointModule):
    """Multi-head self-attention over fixed-size patches of serialized points.

    Points are reordered with ``point.serialized_order[order_index]``, padded
    so every batch element splits into patches of ``patch_size`` points, and
    attended within each patch, either through ``flash_attn`` or through an
    explicit softmax attention.

    Args:
        channels: Feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        patch_size: Points per patch. Without flash attention this is only an
            upper bound; the effective size is chosen per forward pass.
        qkv_bias: Whether the QKV projection has a bias.
        qk_scale: Attention scale; defaults to ``head_dim ** -0.5``.
        attn_drop: Attention dropout probability.
        proj_drop: Dropout probability after the output projection.
        order_index: Which serialization order of the point to use.
        enable_rpe: Add a relative position bias (non-flash path only).
        enable_flash: Use ``flash_attn``. Requires ``enable_rpe``,
            ``upcast_attention`` and ``upcast_softmax`` to all be False.
        upcast_attention: Compute Q @ K^T in float32 (non-flash path).
        upcast_softmax: Compute the softmax in float32 (non-flash path).

    Note:
        The defaults combine ``enable_flash=True`` with the two upcast flags
        set to True, which the assertions below reject; callers in this file
        always pass these flags explicitly.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            # Kept as a plain float: it is passed to flash_attn as dropout_p.
            self.attn_drop = attn_drop
        else:
            # when disable flash attention, we still don't want to use mask
            # consequently, patch size will auto set to the
            # min number of patch_size_max and number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        """Return pairwise relative grid coordinates inside each patch.

        The result is cached on ``point`` under ``rel_pos_{order_index}`` and
        reused on later calls.
        """
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        """Build (and cache on ``point``) the patch padding index tensors.

        Returns:
            A tuple ``(pad, unpad, cu_seqlens)``: ``pad`` indexes the
            serialized order to obtain the padded sequence, ``unpad`` maps
            original positions into the padded sequence, and ``cu_seqlens``
            holds the int32 patch boundaries handed to ``flash_attn``.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            # Round every per-sample point count up to a multiple of patch_size.
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = nn.functional.pad(offset, (1, 0))
            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    # Fill the padded tail of the last patch with indices
                    # copied from the matching span of the previous patch.
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        """Apply patch attention to ``point.feat`` and return ``point``."""
        if not self.enable_flash:
            # No masking on this path: shrink the patch to the smallest sample.
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        # Undo padding and serialization to restore the original point order.
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point


class MLP(nn.Module):
    """Two-layer feed-forward network with dropout after each linear layer.

    Args:
        in_channels: Input feature width.
        hidden_channels: Hidden width; defaults to ``in_channels``.
        out_channels: Output width; defaults to ``in_channels``.
        act_layer: Activation class instantiated between the two layers.
        drop: Dropout probability.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Return ``drop(fc2(drop(act(fc1(x)))))``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(PointModule):
    """Transformer block: sparse-conv positional encoding, attention, MLP.

    Each of the three sub-layers is wrapped in a residual connection. With
    ``pre_norm=True`` normalization runs before attention/MLP, otherwise
    after the residual sum.

    Args:
        channels: Feature width of the block.
        num_heads: Number of attention heads.
        patch_size: Patch size forwarded to ``SerializedAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``channels``.
        qkv_bias, qk_scale, attn_drop, proj_drop: Forwarded to attention
            (``proj_drop`` is also the MLP dropout).
        drop_path: Stochastic-depth rate; 0 disables it.
        layer_scale: Initial LayerScale value, or None to disable LayerScale.
        norm_layer: Normalization class.
        act_layer: Activation class for the MLP.
        pre_norm: Select pre- or post-normalization.
        order_index: Serialization order used by the attention layer.
        cpe_indice_key: ``indice_key`` passed to the ``SubMConv3d`` layer.
        enable_rpe, enable_flash, upcast_attention, upcast_softmax:
            Forwarded to ``SerializedAttention``.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        layer_scale=None,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.ls1 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.ls2 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

    def forward(self, point: Point):
        """Run the block on ``point`` and return it with updated features."""
        # Conditional positional encoding with a residual connection.
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat

        # Attention sub-layer.
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.ls1(self.attn(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        # MLP sub-layer.
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.ls2(self.mlp(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        # Keep the sparse tensor view in sync with the dense features.
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class GridPooling(PointModule):
    """Downsample a point cloud by merging points that share a coarser voxel.

    Args:
        in_channels: Input feature width.
        out_channels: Output feature width (after the linear projection).
        stride: Factor by which grid coordinates are divided.
        norm_layer: Optional normalization class applied after pooling.
        act_layer: Optional activation class applied after pooling.
        reduce: Feature reduction within a voxel: "sum", "mean", "min" or
            "max".
        shuffle_orders: Forwarded to ``Point.serialization``.
        traceable: Record ``pooling_inverse`` and ``pooling_parent`` on the
            pooled point so it can be unpooled later.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        # Always define both attributes: forward() tests them against None, and
        # reading an attribute that was never set on an nn.Module raises
        # AttributeError. A None value is stored as a plain attribute, so the
        # state dict is unaffected.
        self.norm = (
            PointSequential(norm_layer(out_channels))
            if norm_layer is not None
            else None
        )
        self.act = PointSequential(act_layer()) if act_layer is not None else None

    def forward(self, point: Point):
        """Pool ``point`` onto the coarser grid and return the new ``Point``.

        Raises:
            AssertionError: If ``point`` has neither ``grid_coord`` nor both
                ``coord`` and ``grid_size``.
        """
        if "grid_coord" in point.keys():
            grid_coord = point.grid_coord
        elif {"coord", "grid_size"}.issubset(point.keys()):
            grid_coord = torch.div(
                point.coord - point.coord.min(0)[0],
                point.grid_size,
                rounding_mode="trunc",
            ).int()
        else:
            raise AssertionError(
                "[gird_coord] or [coord, grid_size] should be include in the Point"
            )
        grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
        # Pack the batch index into the high bits so that torch.unique keeps
        # voxels from different samples apart; it is masked out again below.
        grid_coord = grid_coord | point.batch.view(-1, 1) << 48
        grid_coord, cluster, counts = torch.unique(
            grid_coord,
            sorted=True,
            return_inverse=True,
            return_counts=True,
            dim=0,
        )
        grid_coord = grid_coord & ((1 << 48) - 1)
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=grid_coord,
            batch=point.batch[head_indices],
        )
        if "origin_coord" in point.keys():
            point_dict["origin_coord"] = torch_scatter.segment_csr(
                point.origin_coord[indices], idx_ptr, reduce="mean"
            )
        # The following entries are carried over unchanged when present.
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context
        if "name" in point.keys():
            point_dict["name"] = point.name
        if "split" in point.keys():
            point_dict["split"] = point.split
        if "color" in point.keys():
            point_dict["color"] = torch_scatter.segment_csr(
                point.color[indices], idx_ptr, reduce="mean"
            )
        if "grid_size" in point.keys():
            point_dict["grid_size"] = point.grid_size * self.stride

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        order = point.order
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.serialization(order=order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        return point


class GridUnpooling(PointModule):
    """Upsample pooled features back onto the parent point cloud.

    The pooled features are projected, gathered with ``pooling_inverse`` and
    added to the projected parent (skip) features.

    Args:
        in_channels: Feature width of the pooled point.
        skip_channels: Feature width of the parent point.
        out_channels: Output feature width.
        norm_layer: Optional normalization class added to both projections.
        act_layer: Optional activation class added to both projections.
        traceable: Store the pooled point on the result as
            ``unpooling_parent``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        """Return the parent point with the upsampled features added.

        ``point`` must carry ``pooling_parent`` and ``pooling_inverse``;
        ``pooling_parent`` is popped from it.
        """
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pooling_inverse
        # Remember the un-projected features so they can be restored below.
        feat = point.feat

        parent = self.proj_skip(parent)
        parent.feat = parent.feat + self.proj(point).feat[inverse]
        parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)

        if self.traceable:
            point.feat = feat
            parent["unpooling_parent"] = point
        return parent


class Embedding(PointModule):
    """Input stem: linear projection with optional norm, activation and mask token.

    Args:
        in_channels: Input feature width.
        embed_channels: Output feature width.
        norm_layer: Optional normalization class appended to the stem.
        act_layer: Optional activation class appended to the stem.
        mask_token: When True, create a learnable token that replaces the
            features of masked points.
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
        mask_token=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

        if mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
        else:
            self.mask_token = None

    def forward(self, point: Point):
        """Embed ``point.feat``; substitute the mask token where ``point.mask``."""
        point = self.stem(point)
        if "mask" in point.keys():
            # NOTE: self.mask_token is None unless the module was built with
            # mask_token=True, in which case this lookup fails.
            point.feat = torch.where(
                point.mask.unsqueeze(-1),
                self.mask_token.to(point.feat.dtype),
                point.feat,
            )
        return point


class PointTransformerV3(PointModule, PyTorchModelHubMixin):
    """Point Transformer V3 backbone with an optional decoder.

    Args:
        in_channels: Width of the input point features.
        order: Serialization order name or sequence of names.
        stride: Pooling stride between consecutive encoder stages.
        enc_depths, enc_channels, enc_num_head, enc_patch_size: Per-stage
            encoder settings; all must have one entry per stage.
        dec_depths, dec_channels, dec_num_head, dec_patch_size: Per-stage
            decoder settings; one entry fewer than the encoder. Ignored when
            ``enc_mode`` is True.
        mlp_ratio, qkv_bias, qk_scale, attn_drop, proj_drop, layer_scale,
            pre_norm, enable_rpe, enable_flash, upcast_attention,
            upcast_softmax: Forwarded to every ``Block``.
        drop_path: Maximum stochastic-depth rate; rates are spaced linearly
            from 0 to this value over the blocks.
        shuffle_orders: Forwarded to ``Point.serialization``.
        traceable: Forwarded to ``GridUnpooling``.
        mask_token: Forwarded to ``Embedding``.
        enc_mode: Build and run the encoder only.
        freeze_encoder: Disable gradients for the embedding and encoder.
    """

    def __init__(
        self,
        in_channels=6,
        order=("z", "z-trans"),
        stride=(2, 2, 2, 2),
        enc_depths=(3, 3, 3, 12, 3),
        enc_channels=(48, 96, 192, 384, 512),
        enc_num_head=(3, 6, 12, 24, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(3, 3, 3, 3),
        dec_channels=(96, 96, 192, 384),
        dec_num_head=(6, 6, 12, 32),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        layer_scale=None,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        traceable=False,
        mask_token=False,
        enc_mode=False,
        freeze_encoder=False,
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.enc_mode = enc_mode
        self.shuffle_orders = shuffle_orders
        self.freeze_encoder = freeze_encoder

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.enc_mode or self.num_stages == len(dec_depths) + 1
        assert self.enc_mode or self.num_stages == len(dec_channels) + 1
        assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
        assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1

        print(f"flash attention: {enable_flash}")

        # normalization layer
        ln_layer = nn.LayerNorm
        # activation layers
        act_layer = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=ln_layer,
            act_layer=act_layer,
            mask_token=mask_token,
        )

        # encoder
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential()
            if s > 0:
                enc.add(
                    GridPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        layer_scale=layer_scale,
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        # Cycle through the serialization orders block by block.
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

        # decoder
        if not self.enc_mode:
            dec_drop_path = [
                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
            ]
            self.dec = PointSequential()
            # The deepest decoder stage consumes the last encoder stage output.
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[
                    sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
                ]
                dec_drop_path_.reverse()
                dec = PointSequential()
                dec.add(
                    GridUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        traceable=traceable,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            layer_scale=layer_scale,
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")
        if self.freeze_encoder:
            for p in self.embedding.parameters():
                p.requires_grad = False
            for p in self.enc.parameters():
                p.requires_grad = False
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """Truncated-normal init for Linear/SubMConv3d weights, zero biases."""
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, spconv.SubMConv3d):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, data_dict):
        """Embed, serialize and sparsify the input, then run encoder/decoder.

        Args:
            data_dict: Mapping used to construct a ``Point``.

        Returns:
            The ``Point`` produced by the encoder, or by the decoder when
            ``enc_mode`` is False.
        """
        point = Point(data_dict)
        point = self.embedding(point)

        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.enc(point)
        if not self.enc_mode:
            point = self.dec(point)
        return point


def _print_trainable_parameters(model):
    """Print the number of parameters of ``model`` that require gradients."""
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")


def load(
    name: str = "sonata",
    repo_id="facebook/sonata",
    download_root: str = None,
    custom_config: dict = None,
    ckpt_only: bool = False,
):
    """Load a checkpoint and build a ``PointTransformerV3`` from it.

    Args:
        name: Either an entry of ``MODELS`` (downloaded from the HuggingFace
            hub) or a path to a local checkpoint file.
        repo_id: HuggingFace repository to download from.
        download_root: Download directory; defaults to
            ``~/.cache/sonata/ckpt``.
        custom_config: Entries that override ``ckpt["config"]`` before the
            model is built, e.g. ``{"enable_flash": False}`` to turn flash
            attention off.
        ckpt_only: Return the (possibly overridden) checkpoint dict instead
            of a model.

    Returns:
        The checkpoint dict when ``ckpt_only`` is True, otherwise the model
        with the checkpoint weights loaded.

    Raises:
        RuntimeError: If ``name`` is neither a known model nor an existing
            file.
    """
    if name in MODELS:
        print(f"Loading checkpoint from HuggingFace: {name} ...")
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{name}.pth",
            repo_type="model",
            revision="main",
            local_dir=download_root or os.path.expanduser("~/.cache/sonata/ckpt"),
        )
    elif os.path.isfile(name):
        print(f"Loading checkpoint in local path: {name} ...")
        ckpt_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {MODELS}")

    # weights_only is only requested on torch >= 2.4.
    if version.parse(torch.__version__) >= version.parse("2.4"):
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    else:
        ckpt = torch.load(ckpt_path, map_location="cpu")
    if custom_config is not None:
        for key, value in custom_config.items():
            ckpt["config"][key] = value

    if ckpt_only:
        return ckpt

    model = PointTransformerV3(**ckpt["config"])
    model.load_state_dict(ckpt["state_dict"])
    _print_trainable_parameters(model)
    return model


def load_by_config(config_path: str):
    """Build a randomly initialized ``PointTransformerV3`` from a JSON config.

    Args:
        config_path: Path to a JSON file whose top-level object holds the
            keyword arguments of ``PointTransformerV3``.

    Returns:
        The constructed model (no checkpoint weights are loaded).
    """
    with open(config_path, "r") as f:
        config = json.load(f)
    model = PointTransformerV3(**config)
    _print_trainable_parameters(model)
    return model