| """ | |
| Point Transformer - V3 Mode2 - Sonata | |
| Pointcept detached version | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json

from packaging import version
from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
from addict import Dict

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
import spconv.pytorch as spconv
import torch_scatter
from timm.layers import DropPath

try:
    import flash_attn
except ImportError:
    flash_attn = None

from .structure import Point
from .module import PointSequential, PointModule
from .utils import offset2bincount

MODELS = [
    "sonata",
    "sonata_small",
    "sonata_linear_prob_head_sc",
]


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class RPE(torch.nn.Module):
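    """Relative position encoding for serialized attention.

    A learnable per-head bias table is indexed by the relative grid
    coordinates between points in a patch (clamped to ``pos_bnd``), with a
    separate table slice per axis; the three axis biases are summed.
    """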

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out


class SerializedAttention(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enabling Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enabling Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            self.attn_drop = attn_drop
        else:
            # When flash attention is disabled, we still avoid attention masks;
            # instead, the effective patch size is set at forward time to the
            # minimum of patch_size_max and the smallest number of points per
            # batch entry.
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)
        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    def get_rel_pos(self, point, order):
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    def get_padding_and_inverse(self, point):
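        """Compute padding maps so every batch entry splits into full patches.

        Returns ``pad`` (indices that gather the padded, patch-aligned point
        sequence from the serialized order, with each entry's padded tail
        filled from the preceding patch), ``unpad`` (indices that map each
        original position back into the padded sequence), and ``cu_seqlens``
        (int32 cumulative patch boundaries in the padded sequence, as consumed
        by flash attention's varlen kernel).
        """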
| pad_key = "pad" | |
| unpad_key = "unpad" | |
| cu_seqlens_key = "cu_seqlens_key" | |
| if ( | |
| pad_key not in point.keys() | |
| or unpad_key not in point.keys() | |
| or cu_seqlens_key not in point.keys() | |
| ): | |
| offset = point.offset | |
| bincount = offset2bincount(offset) | |
| bincount_pad = ( | |
| torch.div( | |
| bincount + self.patch_size - 1, | |
| self.patch_size, | |
| rounding_mode="trunc", | |
| ) | |
| * self.patch_size | |
| ) | |
| # only pad point when num of points larger than patch_size | |
| mask_pad = bincount > self.patch_size | |
| bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad | |
| _offset = nn.functional.pad(offset, (1, 0)) | |
| _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) | |
| pad = torch.arange(_offset_pad[-1], device=offset.device) | |
| unpad = torch.arange(_offset[-1], device=offset.device) | |
| cu_seqlens = [] | |
| for i in range(len(offset)): | |
| unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i] | |
| if bincount[i] != bincount_pad[i]: | |
| pad[ | |
| _offset_pad[i + 1] | |
| - self.patch_size | |
| + (bincount[i] % self.patch_size) : _offset_pad[i + 1] | |
| ] = pad[ | |
| _offset_pad[i + 1] | |
| - 2 * self.patch_size | |
| + (bincount[i] % self.patch_size) : _offset_pad[i + 1] | |
| - self.patch_size | |
| ] | |
| pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] | |
| cu_seqlens.append( | |
| torch.arange( | |
| _offset_pad[i], | |
| _offset_pad[i + 1], | |
| step=self.patch_size, | |
| dtype=torch.int32, | |
| device=offset.device, | |
| ) | |
| ) | |
| point[pad_key] = pad | |
| point[unpad_key] = unpad | |
| point[cu_seqlens_key] = nn.functional.pad( | |
| torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1] | |
| ) | |
| return point[pad_key], point[unpad_key], point[cu_seqlens_key] | |

    def forward(self, point):
        if not self.enable_flash:
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # pad and reshape feat for the serialized point patches
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # output projection
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point


class MLP(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        layer_scale=None,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.ls1 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.ls2 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

    def forward(self, point: Point):
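        # Conditional positional encoding (sparse conv) with a residual, then
        # serialized attention and MLP sub-blocks, each with its own residual;
        # `pre_norm` toggles pre- vs. post-normalization.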
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat
        shortcut = point.feat

        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.ls1(self.attn(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.ls2(self.mlp(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class GridPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        # default to None so the `is not None` checks in forward don't raise
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        if "grid_coord" in point.keys():
            grid_coord = point.grid_coord
        elif {"coord", "grid_size"}.issubset(point.keys()):
            grid_coord = torch.div(
                point.coord - point.coord.min(0)[0],
                point.grid_size,
                rounding_mode="trunc",
            ).int()
        else:
            raise AssertionError(
                "[grid_coord] or [coord, grid_size] should be included in the Point"
            )
        grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
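        # pack the batch index into the high bits so points from different
        # batch entries never fall into the same voxel cluster; the high bits
        # are masked off again after torch.unique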
        grid_coord = grid_coord | (point.batch.view(-1, 1) << 48)
        grid_coord, cluster, counts = torch.unique(
            grid_coord,
            sorted=True,
            return_inverse=True,
            return_counts=True,
            dim=0,
        )
        grid_coord = grid_coord & ((1 << 48) - 1)
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=grid_coord,
            batch=point.batch[head_indices],
        )
        if "origin_coord" in point.keys():
            point_dict["origin_coord"] = torch_scatter.segment_csr(
                point.origin_coord[indices], idx_ptr, reduce="mean"
            )
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context
        if "name" in point.keys():
            point_dict["name"] = point.name
        if "split" in point.keys():
            point_dict["split"] = point.split
        if "color" in point.keys():
            point_dict["color"] = torch_scatter.segment_csr(
                point.color[indices], idx_ptr, reduce="mean"
            )
        if "grid_size" in point.keys():
            point_dict["grid_size"] = point.grid_size * self.stride

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        order = point.order
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.serialization(order=order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        return point


class GridUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))
        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())
        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pooling_inverse
        feat = point.feat
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + self.proj(point).feat[inverse]
        parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)

        if self.traceable:
            point.feat = feat
            parent["unpooling_parent"] = point
        return parent


class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
        mask_token=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")
        if mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
        else:
            self.mask_token = None

    def forward(self, point: Point):
        point = self.stem(point)
        if "mask" in point.keys():
            point.feat = torch.where(
                point.mask.unsqueeze(-1),
                self.mask_token.to(point.feat.dtype),
                point.feat,
            )
        return point


class PointTransformerV3(PointModule, PyTorchModelHubMixin):
    def __init__(
        self,
        in_channels=6,
        order=("z", "z-trans"),
        stride=(2, 2, 2, 2),
        enc_depths=(3, 3, 3, 12, 3),
        enc_channels=(48, 96, 192, 384, 512),
        enc_num_head=(3, 6, 12, 24, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(3, 3, 3, 3),
        dec_channels=(96, 96, 192, 384),
        dec_num_head=(6, 6, 12, 32),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        layer_scale=None,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        traceable=False,
        mask_token=False,
        enc_mode=False,
        freeze_encoder=False,
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.enc_mode = enc_mode
        self.shuffle_orders = shuffle_orders
        self.freeze_encoder = freeze_encoder

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.enc_mode or self.num_stages == len(dec_depths) + 1
        assert self.enc_mode or self.num_stages == len(dec_channels) + 1
        assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
        assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1
        print(f"flash attention: {enable_flash}")

        # normalization layer
        ln_layer = nn.LayerNorm
        # activation layers
        act_layer = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=ln_layer,
            act_layer=act_layer,
            mask_token=mask_token,
        )

        # encoder
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential()
            if s > 0:
                enc.add(
                    GridPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        layer_scale=layer_scale,
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

        # decoder
        if not self.enc_mode:
            dec_drop_path = [
                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
            ]
            self.dec = PointSequential()
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[
                    sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
                ]
                dec_drop_path_.reverse()
                dec = PointSequential()
                dec.add(
                    GridUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        traceable=traceable,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            layer_scale=layer_scale,
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")

        if self.freeze_encoder:
            for p in self.embedding.parameters():
                p.requires_grad = False
            for p in self.enc.parameters():
                p.requires_grad = False
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, spconv.SubMConv3d):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, data_dict):
        point = Point(data_dict)
        point = self.embedding(point)
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.enc(point)
        if not self.enc_mode:
            point = self.dec(point)
        return point


def load(
    name: str = "sonata",
    repo_id="facebook/sonata",
    download_root: str = None,
    custom_config: dict = None,
    ckpt_only: bool = False,
):
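    """Load a Sonata checkpoint by name (fetched from the HuggingFace Hub)
    or from a local file path, optionally overriding entries of the stored
    model config via `custom_config`.
    """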
    if name in MODELS:
        print(f"Loading checkpoint from HuggingFace: {name} ...")
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{name}.pth",
            repo_type="model",
            revision="main",
            local_dir=download_root or os.path.expanduser("~/.cache/sonata/ckpt"),
        )
    elif os.path.isfile(name):
        print(f"Loading checkpoint from local path: {name} ...")
        ckpt_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {MODELS}")
    if version.parse(torch.__version__) >= version.parse("2.4"):
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    else:
        ckpt = torch.load(ckpt_path, map_location="cpu")
    if custom_config is not None:
        for key, value in custom_config.items():
            ckpt["config"][key] = value
    if ckpt_only:
        return ckpt
    # disable flash attention
    # ckpt["config"]["enable_flash"] = False
    model = PointTransformerV3(**ckpt["config"])
    model.load_state_dict(ckpt["state_dict"])
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
    return model


def load_by_config(config_path: str):
    with open(config_path, "r") as f:
        config = json.load(f)
    model = PointTransformerV3(**config)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
    return model
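

# Example usage (a minimal sketch, not part of the original module). It
# assumes this file lives inside a package (called `sonata` below, which is
# illustrative) so the relative imports resolve, and that the "sonata"
# checkpoint is reachable on the HuggingFace Hub. The dict keys follow what
# `Point` consumes above (coord / feat / batch / grid_size):
#
#   import torch
#   from sonata import model as sonata_model
#
#   net = sonata_model.load("sonata").cuda().eval()
#   n = 1000
#   data_dict = {
#       "coord": torch.rand(n, 3).cuda(),                  # xyz positions
#       "feat": torch.rand(n, 6).cuda(),                   # matches in_channels=6
#       "batch": torch.zeros(n, dtype=torch.long).cuda(),  # single batch entry
#       "grid_size": 0.02,                                 # voxel size for serialization
#   }
#   with torch.inference_mode():
#       point = net(data_dict)  # per-point features end up in point.feat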