|
|
""" |
|
|
Point Transformer - V3 Mode2 - Sonata & Concerto |
|
|
Pointcept detached version |
|
|
|
|
|
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
|
|
Please cite our work if the code is helpful to you. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
from packaging import version |
|
|
from huggingface_hub import hf_hub_download, PyTorchModelHubMixin |
|
|
from addict import Dict |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn.init import trunc_normal_ |
|
|
import spconv.pytorch as spconv |
|
|
import torch_scatter |
|
|
from timm.layers import DropPath |
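
# flash-attn is an optional dependency; when it is unavailable, construct
# models with enable_flash=False to fall back to plain (optionally
# RPE-biased) patch attention.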
|
|
|
|
|
|
|
|
try: |
|
|
import flash_attn |
|
|
except ImportError: |
|
|
flash_attn = None |
|
|
|
|
|
from .structure import Point |
|
|
from .module import PointSequential, PointModule |
|
|
from .utils import offset2bincount |
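
# Checkpoint names resolvable by load(); any other string is treated as a
# local checkpoint path.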
|
|
|
|
|
MODELS = [ |
|
|
"sonata", |
|
|
"concerto_large", |
|
|
"concerto_base", |
|
|
"concerto_small", |
|
|
"concerto_large_linear_prob_head_sc", |
|
|
] |
|
|
|
|
|
|
|
|
class LayerScale(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
init_values: float = 1e-5, |
|
|
inplace: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.inplace = inplace |
|
|
self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
|
|
|
class RPE(torch.nn.Module): |
|
|
def __init__(self, patch_size, num_heads): |
|
|
super().__init__() |
|
|
self.patch_size = patch_size |
|
|
self.num_heads = num_heads |
|
|
self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2) |
|
|
self.rpe_num = 2 * self.pos_bnd + 1 |
|
|
self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads)) |
|
|
torch.nn.init.trunc_normal_(self.rpe_table, std=0.02) |
|
|
|
|
|
def forward(self, coord): |
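        # coord: (N, K, K, 3) relative grid offsets within a serialized patch
        # (clamped to +/-pos_bnd); returns an (N, num_heads, K, K) additive
        # attention bias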
|
|
idx = ( |
|
|
coord.clamp(-self.pos_bnd, self.pos_bnd) |
|
|
+ self.pos_bnd |
|
|
+ torch.arange(3, device=coord.device) * self.rpe_num |
|
|
) |
|
|
out = self.rpe_table.index_select(0, idx.reshape(-1)) |
|
|
out = out.view(idx.shape + (-1,)).sum(3) |
|
|
out = out.permute(0, 3, 1, 2) |
|
|
return out |
|
|
|
|
|
|
|
|
class SerializedAttention(PointModule): |
|
|
def __init__( |
|
|
self, |
|
|
channels, |
|
|
num_heads, |
|
|
patch_size, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
attn_drop=0.0, |
|
|
proj_drop=0.0, |
|
|
order_index=0, |
|
|
enable_rpe=False, |
|
|
enable_flash=True, |
|
|
upcast_attention=True, |
|
|
upcast_softmax=True, |
|
|
): |
|
|
super().__init__() |
|
|
assert channels % num_heads == 0 |
|
|
self.channels = channels |
|
|
self.num_heads = num_heads |
|
|
self.scale = qk_scale or (channels // num_heads) ** -0.5 |
|
|
self.order_index = order_index |
|
|
self.upcast_attention = upcast_attention |
|
|
self.upcast_softmax = upcast_softmax |
|
|
self.enable_rpe = enable_rpe |
|
|
self.enable_flash = enable_flash |
|
|
        if enable_flash:
            assert not enable_rpe, "Set enable_rpe to False when enabling Flash Attention"
            assert not upcast_attention, "Set upcast_attention to False when enabling Flash Attention"
            assert not upcast_softmax, "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
|
|
self.patch_size = patch_size |
|
|
self.attn_drop = attn_drop |
|
|
else: |
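            # without flash-attn, attention is computed densely per patch;
            # patch_size is re-derived each forward() as
            # min(patch_size_max, smallest per-sample point count)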
|
|
|
|
|
|
|
|
|
|
|
self.patch_size_max = patch_size |
|
|
self.patch_size = 0 |
|
|
self.attn_drop = torch.nn.Dropout(attn_drop) |
|
|
|
|
|
self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias) |
|
|
self.proj = torch.nn.Linear(channels, channels) |
|
|
self.proj_drop = torch.nn.Dropout(proj_drop) |
|
|
self.softmax = torch.nn.Softmax(dim=-1) |
|
|
self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None |
|
|
|
|
|
@torch.no_grad() |
|
|
def get_rel_pos(self, point, order): |
|
|
K = self.patch_size |
|
|
rel_pos_key = f"rel_pos_{self.order_index}" |
|
|
if rel_pos_key not in point.keys(): |
|
|
grid_coord = point.grid_coord[order] |
|
|
grid_coord = grid_coord.reshape(-1, K, 3) |
|
|
point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1) |
|
|
return point[rel_pos_key] |
|
|
|
|
|
@torch.no_grad() |
|
|
def get_padding_and_inverse(self, point): |
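        # Cached once per Point: "pad" maps each padded serialized slot to a
        # source index (the last partial patch re-uses indices from the
        # preceding patch), "unpad" inverts the padding, and cu_seqlens marks
        # patch boundaries in the layout flash-attn's varlen kernel expects.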
|
|
pad_key = "pad" |
|
|
unpad_key = "unpad" |
|
|
cu_seqlens_key = "cu_seqlens_key" |
|
|
if ( |
|
|
pad_key not in point.keys() |
|
|
or unpad_key not in point.keys() |
|
|
or cu_seqlens_key not in point.keys() |
|
|
): |
|
|
offset = point.offset |
|
|
bincount = offset2bincount(offset) |
|
|
bincount_pad = ( |
|
|
torch.div( |
|
|
bincount + self.patch_size - 1, |
|
|
self.patch_size, |
|
|
rounding_mode="trunc", |
|
|
) |
|
|
* self.patch_size |
|
|
) |
|
|
|
|
|
mask_pad = bincount > self.patch_size |
|
|
bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad |
|
|
_offset = nn.functional.pad(offset, (1, 0)) |
|
|
_offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) |
|
|
pad = torch.arange(_offset_pad[-1], device=offset.device) |
|
|
unpad = torch.arange(_offset[-1], device=offset.device) |
|
|
cu_seqlens = [] |
|
|
for i in range(len(offset)): |
|
|
unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i] |
|
|
if bincount[i] != bincount_pad[i]: |
|
|
pad[ |
|
|
_offset_pad[i + 1] |
|
|
- self.patch_size |
|
|
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1] |
|
|
] = pad[ |
|
|
_offset_pad[i + 1] |
|
|
- 2 * self.patch_size |
|
|
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1] |
|
|
- self.patch_size |
|
|
] |
|
|
pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] |
|
|
cu_seqlens.append( |
|
|
torch.arange( |
|
|
_offset_pad[i], |
|
|
_offset_pad[i + 1], |
|
|
step=self.patch_size, |
|
|
dtype=torch.int32, |
|
|
device=offset.device, |
|
|
) |
|
|
) |
|
|
point[pad_key] = pad |
|
|
point[unpad_key] = unpad |
|
|
point[cu_seqlens_key] = nn.functional.pad( |
|
|
torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1] |
|
|
) |
|
|
return point[pad_key], point[unpad_key], point[cu_seqlens_key] |
|
|
|
|
|
def forward(self, point): |
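        # attention runs within fixed-size patches of the serialized order;
        # without flash-attn the patch size is additionally capped by the
        # smallest per-sample point count in the batch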
|
|
if not self.enable_flash: |
|
|
self.patch_size = min( |
|
|
offset2bincount(point.offset).min().tolist(), self.patch_size_max |
|
|
) |
|
|
|
|
|
H = self.num_heads |
|
|
K = self.patch_size |
|
|
C = self.channels |
|
|
|
|
|
pad, unpad, cu_seqlens = self.get_padding_and_inverse(point) |
|
|
|
|
|
order = point.serialized_order[self.order_index][pad] |
|
|
inverse = unpad[point.serialized_inverse[self.order_index]] |
|
|
|
|
|
|
|
|
qkv = self.qkv(point.feat)[order] |
|
|
|
|
|
if not self.enable_flash: |
|
|
|
|
|
q, k, v = ( |
|
|
qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0) |
|
|
) |
|
|
|
|
|
if self.upcast_attention: |
|
|
q = q.float() |
|
|
k = k.float() |
|
|
attn = (q * self.scale) @ k.transpose(-2, -1) |
|
|
if self.enable_rpe: |
|
|
attn = attn + self.rpe(self.get_rel_pos(point, order)) |
|
|
if self.upcast_softmax: |
|
|
attn = attn.float() |
|
|
attn = self.softmax(attn) |
|
|
attn = self.attn_drop(attn).to(qkv.dtype) |
|
|
feat = (attn @ v).transpose(1, 2).reshape(-1, C) |
|
|
else: |
|
|
feat = flash_attn.flash_attn_varlen_qkvpacked_func( |
|
|
qkv.half().reshape(-1, 3, H, C // H), |
|
|
cu_seqlens, |
|
|
max_seqlen=self.patch_size, |
|
|
dropout_p=self.attn_drop if self.training else 0, |
|
|
softmax_scale=self.scale, |
|
|
).reshape(-1, C) |
|
|
feat = feat.to(qkv.dtype) |
|
|
feat = feat[inverse] |
|
|
|
|
|
|
|
|
feat = self.proj(feat) |
|
|
feat = self.proj_drop(feat) |
|
|
point.feat = feat |
|
|
return point |
|
|
|
|
|
|
|
|
class MLP(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels, |
|
|
hidden_channels=None, |
|
|
out_channels=None, |
|
|
act_layer=nn.GELU, |
|
|
drop=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
out_channels = out_channels or in_channels |
|
|
hidden_channels = hidden_channels or in_channels |
|
|
self.fc1 = nn.Linear(in_channels, hidden_channels) |
|
|
self.act = act_layer() |
|
|
self.fc2 = nn.Linear(hidden_channels, out_channels) |
|
|
self.drop = nn.Dropout(drop) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.fc1(x) |
|
|
x = self.act(x) |
|
|
x = self.drop(x) |
|
|
x = self.fc2(x) |
|
|
x = self.drop(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class Block(PointModule): |
|
|
def __init__( |
|
|
self, |
|
|
channels, |
|
|
num_heads, |
|
|
patch_size=48, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
attn_drop=0.0, |
|
|
proj_drop=0.0, |
|
|
drop_path=0.0, |
|
|
layer_scale=None, |
|
|
norm_layer=nn.LayerNorm, |
|
|
act_layer=nn.GELU, |
|
|
pre_norm=True, |
|
|
order_index=0, |
|
|
cpe_indice_key=None, |
|
|
enable_rpe=False, |
|
|
enable_flash=True, |
|
|
upcast_attention=True, |
|
|
upcast_softmax=True, |
|
|
): |
|
|
super().__init__() |
|
|
self.channels = channels |
|
|
self.pre_norm = pre_norm |
|
|
|
|
|
self.cpe = PointSequential( |
|
|
spconv.SubMConv3d( |
|
|
channels, |
|
|
channels, |
|
|
kernel_size=3, |
|
|
bias=True, |
|
|
indice_key=cpe_indice_key, |
|
|
), |
|
|
nn.Linear(channels, channels), |
|
|
norm_layer(channels), |
|
|
) |
|
|
|
|
|
self.norm1 = PointSequential(norm_layer(channels)) |
|
|
self.ls1 = PointSequential( |
|
|
LayerScale(channels, init_values=layer_scale) |
|
|
if layer_scale is not None |
|
|
else nn.Identity() |
|
|
) |
|
|
self.attn = SerializedAttention( |
|
|
channels=channels, |
|
|
patch_size=patch_size, |
|
|
num_heads=num_heads, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=proj_drop, |
|
|
order_index=order_index, |
|
|
enable_rpe=enable_rpe, |
|
|
enable_flash=enable_flash, |
|
|
upcast_attention=upcast_attention, |
|
|
upcast_softmax=upcast_softmax, |
|
|
) |
|
|
self.norm2 = PointSequential(norm_layer(channels)) |
|
|
self.ls2 = PointSequential( |
|
|
LayerScale(channels, init_values=layer_scale) |
|
|
if layer_scale is not None |
|
|
else nn.Identity() |
|
|
) |
|
|
self.mlp = PointSequential( |
|
|
MLP( |
|
|
in_channels=channels, |
|
|
hidden_channels=int(channels * mlp_ratio), |
|
|
out_channels=channels, |
|
|
act_layer=act_layer, |
|
|
drop=proj_drop, |
|
|
) |
|
|
) |
|
|
self.drop_path = PointSequential( |
|
|
DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
) |
|
|
|
|
|
def forward(self, point: Point): |
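        # residual CPE (sparse conv + linear), then pre-/post-norm attention
        # and MLP residual branches; the sparse tensor is refreshed at the end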
|
|
shortcut = point.feat |
|
|
point = self.cpe(point) |
|
|
point.feat = shortcut + point.feat |
|
|
shortcut = point.feat |
|
|
if self.pre_norm: |
|
|
point = self.norm1(point) |
|
|
point = self.drop_path(self.ls1(self.attn(point))) |
|
|
point.feat = shortcut + point.feat |
|
|
if not self.pre_norm: |
|
|
point = self.norm1(point) |
|
|
|
|
|
shortcut = point.feat |
|
|
if self.pre_norm: |
|
|
point = self.norm2(point) |
|
|
point = self.drop_path(self.ls2(self.mlp(point))) |
|
|
point.feat = shortcut + point.feat |
|
|
if not self.pre_norm: |
|
|
point = self.norm2(point) |
|
|
point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) |
|
|
return point |
|
|
|
|
|
|
|
|
class GridPooling(PointModule): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels, |
|
|
out_channels, |
|
|
stride=2, |
|
|
norm_layer=None, |
|
|
act_layer=None, |
|
|
reduce="max", |
|
|
shuffle_orders=True, |
|
|
traceable=True, |
|
|
): |
|
|
super().__init__() |
|
|
self.in_channels = in_channels |
|
|
self.out_channels = out_channels |
|
|
|
|
|
self.stride = stride |
|
|
assert reduce in ["sum", "mean", "min", "max"] |
|
|
self.reduce = reduce |
|
|
self.shuffle_orders = shuffle_orders |
|
|
self.traceable = traceable |
|
|
|
|
|
self.proj = nn.Linear(in_channels, out_channels) |
|
|
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        else:
            self.norm = None
        if act_layer is not None:
            self.act = PointSequential(act_layer())
        else:
            self.act = None
|
|
|
|
|
def forward(self, point: Point): |
|
|
if "grid_coord" in point.keys(): |
|
|
grid_coord = point.grid_coord |
|
|
elif {"coord", "grid_size"}.issubset(point.keys()): |
|
|
grid_coord = torch.div( |
|
|
point.coord - point.coord.min(0)[0], |
|
|
point.grid_size, |
|
|
rounding_mode="trunc", |
|
|
).int() |
|
|
else: |
|
|
raise AssertionError( |
|
|
"[gird_coord] or [coord, grid_size] should be include in the Point" |
|
|
) |
|
|
grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc") |
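        # pack the batch index into bits 48+ so torch.unique clusters voxels
        # per sample ("<<" binds tighter than "|"); the mask below strips the
        # batch bits again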
|
|
grid_coord = grid_coord | point.batch.view(-1, 1) << 48 |
|
|
grid_coord, cluster, counts = torch.unique( |
|
|
grid_coord, |
|
|
sorted=True, |
|
|
return_inverse=True, |
|
|
return_counts=True, |
|
|
dim=0, |
|
|
) |
|
|
grid_coord = grid_coord & ((1 << 48) - 1) |
|
|
|
|
|
_, indices = torch.sort(cluster) |
|
|
|
|
|
idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) |
|
|
|
|
|
head_indices = indices[idx_ptr[:-1]] |
|
|
point_dict = Dict( |
|
|
feat=torch_scatter.segment_csr( |
|
|
self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce |
|
|
), |
|
|
coord=torch_scatter.segment_csr( |
|
|
point.coord[indices], idx_ptr, reduce="mean" |
|
|
), |
|
|
grid_coord=grid_coord, |
|
|
batch=point.batch[head_indices], |
|
|
) |
|
|
if "origin_coord" in point.keys(): |
|
|
point_dict["origin_coord"] = torch_scatter.segment_csr( |
|
|
point.origin_coord[indices], idx_ptr, reduce="mean" |
|
|
) |
|
|
if "condition" in point.keys(): |
|
|
point_dict["condition"] = point.condition |
|
|
if "context" in point.keys(): |
|
|
point_dict["context"] = point.context |
|
|
if "name" in point.keys(): |
|
|
point_dict["name"] = point.name |
|
|
if "split" in point.keys(): |
|
|
point_dict["split"] = point.split |
|
|
if "color" in point.keys(): |
|
|
point_dict["color"] = torch_scatter.segment_csr( |
|
|
point.color[indices], idx_ptr, reduce="mean" |
|
|
) |
|
|
if "grid_size" in point.keys(): |
|
|
point_dict["grid_size"] = point.grid_size * self.stride |
|
|
|
|
|
if self.traceable: |
|
|
point_dict["pooling_inverse"] = cluster |
|
|
point_dict["pooling_parent"] = point |
|
|
point_dict["idx_ptr"] = idx_ptr |
|
|
order = point.order |
|
|
point = Point(point_dict) |
|
|
if self.norm is not None: |
|
|
point = self.norm(point) |
|
|
if self.act is not None: |
|
|
point = self.act(point) |
|
|
point.serialization(order=order, shuffle_orders=self.shuffle_orders) |
|
|
point.sparsify() |
|
|
return point |
|
|
|
|
|
|
|
|
class GridUnpooling(PointModule): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels, |
|
|
skip_channels, |
|
|
out_channels, |
|
|
norm_layer=None, |
|
|
act_layer=None, |
|
|
traceable=False, |
|
|
): |
|
|
super().__init__() |
|
|
self.proj = PointSequential(nn.Linear(in_channels, out_channels)) |
|
|
self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels)) |
|
|
|
|
|
if norm_layer is not None: |
|
|
self.proj.add(norm_layer(out_channels)) |
|
|
self.proj_skip.add(norm_layer(out_channels)) |
|
|
|
|
|
if act_layer is not None: |
|
|
self.proj.add(act_layer()) |
|
|
self.proj_skip.add(act_layer()) |
|
|
|
|
|
self.traceable = traceable |
|
|
|
|
|
def forward(self, point): |
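        # broadcast coarse features back to the parent resolution via the
        # cached pooling_inverse and fuse them with the projected skip features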
|
|
assert "pooling_parent" in point.keys() |
|
|
assert "pooling_inverse" in point.keys() |
|
|
parent = point.pop("pooling_parent") |
|
|
inverse = point.pooling_inverse |
|
|
feat = point.feat |
|
|
|
|
|
parent = self.proj_skip(parent) |
|
|
parent.feat = parent.feat + self.proj(point).feat[inverse] |
|
|
parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat) |
|
|
|
|
|
if self.traceable: |
|
|
point.feat = feat |
|
|
parent["unpooling_parent"] = point |
|
|
return parent |
|
|
|
|
|
|
|
|
class Embedding(PointModule): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels, |
|
|
embed_channels, |
|
|
norm_layer=None, |
|
|
act_layer=None, |
|
|
mask_token=False, |
|
|
): |
|
|
super().__init__() |
|
|
self.in_channels = in_channels |
|
|
self.embed_channels = embed_channels |
|
|
|
|
|
self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels)) |
|
|
if norm_layer is not None: |
|
|
self.stem.add(norm_layer(embed_channels), name="norm") |
|
|
if act_layer is not None: |
|
|
self.stem.add(act_layer(), name="act") |
|
|
|
|
|
if mask_token: |
|
|
self.mask_token = nn.Parameter(torch.zeros(1, embed_channels)) |
|
|
else: |
|
|
self.mask_token = None |
|
|
|
|
|
def forward(self, point: Point): |
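        # project input features; if the Point carries a "mask", masked
        # points are replaced by the learnable mask token (requires
        # mask_token=True at construction)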
|
|
point = self.stem(point) |
|
|
if "mask" in point.keys(): |
|
|
point.feat = torch.where( |
|
|
point.mask.unsqueeze(-1), |
|
|
self.mask_token.to(point.feat.dtype), |
|
|
point.feat, |
|
|
) |
|
|
return point |
|
|
|
|
|
|
|
|
class PointTransformerV3(PointModule, PyTorchModelHubMixin): |
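    """Serialization-based Point Transformer V3 backbone.

    Encoder stages interleave GridPooling downsampling with attention
    blocks; unless enc_mode=True, a symmetric decoder with GridUnpooling
    restores per-point resolution.
    """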
|
|
def __init__( |
|
|
self, |
|
|
in_channels=6, |
|
|
order=("z", "z-trans"), |
|
|
stride=(2, 2, 2, 2), |
|
|
enc_depths=(3, 3, 3, 12, 3), |
|
|
enc_channels=(48, 96, 192, 384, 512), |
|
|
enc_num_head=(3, 6, 12, 24, 32), |
|
|
enc_patch_size=(1024, 1024, 1024, 1024, 1024), |
|
|
dec_depths=(3, 3, 3, 3), |
|
|
dec_channels=(96, 96, 192, 384), |
|
|
dec_num_head=(6, 6, 12, 32), |
|
|
dec_patch_size=(1024, 1024, 1024, 1024), |
|
|
mlp_ratio=4, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
attn_drop=0.0, |
|
|
proj_drop=0.0, |
|
|
drop_path=0.3, |
|
|
layer_scale=None, |
|
|
pre_norm=True, |
|
|
shuffle_orders=True, |
|
|
enable_rpe=False, |
|
|
enable_flash=True, |
|
|
upcast_attention=False, |
|
|
upcast_softmax=False, |
|
|
traceable=False, |
|
|
mask_token=False, |
|
|
enc_mode=False, |
|
|
freeze_encoder=False, |
|
|
): |
|
|
super().__init__() |
|
|
self.num_stages = len(enc_depths) |
|
|
self.order = [order] if isinstance(order, str) else order |
|
|
self.enc_mode = enc_mode |
|
|
self.shuffle_orders = shuffle_orders |
|
|
self.freeze_encoder = freeze_encoder |
|
|
|
|
|
assert self.num_stages == len(stride) + 1 |
|
|
assert self.num_stages == len(enc_depths) |
|
|
assert self.num_stages == len(enc_channels) |
|
|
assert self.num_stages == len(enc_num_head) |
|
|
assert self.num_stages == len(enc_patch_size) |
|
|
assert self.enc_mode or self.num_stages == len(dec_depths) + 1 |
|
|
assert self.enc_mode or self.num_stages == len(dec_channels) + 1 |
|
|
assert self.enc_mode or self.num_stages == len(dec_num_head) + 1 |
|
|
assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1 |
|
|
|
|
|
|
|
|
ln_layer = nn.LayerNorm |
|
|
|
|
|
act_layer = nn.GELU |
|
|
|
|
|
self.embedding = Embedding( |
|
|
in_channels=in_channels, |
|
|
embed_channels=enc_channels[0], |
|
|
norm_layer=ln_layer, |
|
|
act_layer=act_layer, |
|
|
mask_token=mask_token, |
|
|
) |
|
|
|
|
|
|
|
|
enc_drop_path = [ |
|
|
x.item() for x in torch.linspace(0, drop_path, sum(enc_depths)) |
|
|
] |
|
|
self.enc = PointSequential() |
|
|
for s in range(self.num_stages): |
|
|
enc_drop_path_ = enc_drop_path[ |
|
|
sum(enc_depths[:s]) : sum(enc_depths[: s + 1]) |
|
|
] |
|
|
enc = PointSequential() |
|
|
if s > 0: |
|
|
enc.add( |
|
|
GridPooling( |
|
|
in_channels=enc_channels[s - 1], |
|
|
out_channels=enc_channels[s], |
|
|
stride=stride[s - 1], |
|
|
norm_layer=ln_layer, |
|
|
act_layer=act_layer, |
|
|
), |
|
|
name="down", |
|
|
) |
|
|
for i in range(enc_depths[s]): |
|
|
enc.add( |
|
|
Block( |
|
|
channels=enc_channels[s], |
|
|
num_heads=enc_num_head[s], |
|
|
patch_size=enc_patch_size[s], |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=proj_drop, |
|
|
drop_path=enc_drop_path_[i], |
|
|
layer_scale=layer_scale, |
|
|
norm_layer=ln_layer, |
|
|
act_layer=act_layer, |
|
|
pre_norm=pre_norm, |
|
|
order_index=i % len(self.order), |
|
|
cpe_indice_key=f"stage{s}", |
|
|
enable_rpe=enable_rpe, |
|
|
enable_flash=enable_flash, |
|
|
upcast_attention=upcast_attention, |
|
|
upcast_softmax=upcast_softmax, |
|
|
), |
|
|
name=f"block{i}", |
|
|
) |
|
|
if len(enc) != 0: |
|
|
self.enc.add(module=enc, name=f"enc{s}") |
|
|
|
|
|
|
|
|
if not self.enc_mode: |
|
|
dec_drop_path = [ |
|
|
x.item() for x in torch.linspace(0, drop_path, sum(dec_depths)) |
|
|
] |
|
|
self.dec = PointSequential() |
|
|
dec_channels = list(dec_channels) + [enc_channels[-1]] |
|
|
for s in reversed(range(self.num_stages - 1)): |
|
|
dec_drop_path_ = dec_drop_path[ |
|
|
sum(dec_depths[:s]) : sum(dec_depths[: s + 1]) |
|
|
] |
|
|
dec_drop_path_.reverse() |
|
|
dec = PointSequential() |
|
|
dec.add( |
|
|
GridUnpooling( |
|
|
in_channels=dec_channels[s + 1], |
|
|
skip_channels=enc_channels[s], |
|
|
out_channels=dec_channels[s], |
|
|
norm_layer=ln_layer, |
|
|
act_layer=act_layer, |
|
|
traceable=traceable, |
|
|
), |
|
|
name="up", |
|
|
) |
|
|
for i in range(dec_depths[s]): |
|
|
dec.add( |
|
|
Block( |
|
|
channels=dec_channels[s], |
|
|
num_heads=dec_num_head[s], |
|
|
patch_size=dec_patch_size[s], |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=proj_drop, |
|
|
drop_path=dec_drop_path_[i], |
|
|
layer_scale=layer_scale, |
|
|
norm_layer=ln_layer, |
|
|
act_layer=act_layer, |
|
|
pre_norm=pre_norm, |
|
|
order_index=i % len(self.order), |
|
|
cpe_indice_key=f"stage{s}", |
|
|
enable_rpe=enable_rpe, |
|
|
enable_flash=enable_flash, |
|
|
upcast_attention=upcast_attention, |
|
|
upcast_softmax=upcast_softmax, |
|
|
), |
|
|
name=f"block{i}", |
|
|
) |
|
|
self.dec.add(module=dec, name=f"dec{s}") |
|
|
if self.freeze_encoder: |
|
|
for p in self.embedding.parameters(): |
|
|
p.requires_grad = False |
|
|
for p in self.enc.parameters(): |
|
|
p.requires_grad = False |
|
|
self.apply(self._init_weights) |
|
|
|
|
|
@staticmethod |
|
|
def _init_weights(module): |
|
|
if isinstance(module, nn.Linear): |
|
|
trunc_normal_(module.weight, std=0.02) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, spconv.SubMConv3d): |
|
|
trunc_normal_(module.weight, std=0.02) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
def forward(self, data_dict): |
|
|
point = Point(data_dict) |
|
|
point = self.embedding(point) |
|
|
|
|
|
point.serialization(order=self.order, shuffle_orders=self.shuffle_orders) |
|
|
point.sparsify() |
|
|
|
|
|
point = self.enc(point) |
|
|
if not self.enc_mode: |
|
|
point = self.dec(point) |
|
|
return point |
|
|
|
|
|
|
|
|
def load( |
|
|
name: str = "concerto_large", |
|
|
repo_id="Pointcept/Concerto", |
|
|
download_root: str = None, |
|
|
custom_config: dict = None, |
|
|
ckpt_only: bool = False, |
|
|
): |
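    """Load a PointTransformerV3 from a named HuggingFace checkpoint
    (see MODELS) or a local checkpoint path.

    Entries in custom_config override the stored model config; with
    ckpt_only=True the raw checkpoint dict is returned instead of a model.
    """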
|
|
if name in MODELS: |
|
|
print(f"Loading checkpoint from HuggingFace: {name} ...") |
|
|
ckpt_path = hf_hub_download( |
|
|
repo_id=repo_id, |
|
|
filename=f"{name}.pth", |
|
|
repo_type="model", |
|
|
revision="main", |
|
|
local_dir=download_root or os.path.expanduser("~/.cache/concerto/ckpt"), |
|
|
) |
|
|
elif os.path.isfile(name): |
|
|
print(f"Loading checkpoint in local path: {name} ...") |
|
|
ckpt_path = name |
|
|
else: |
|
|
raise RuntimeError(f"Model {name} not found; available models = {MODELS}") |
|
|
|
|
|
if version.parse(torch.__version__) >= version.parse("2.4"): |
|
|
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) |
|
|
else: |
|
|
ckpt = torch.load(ckpt_path, map_location="cpu") |
|
|
if custom_config is not None: |
|
|
for key, value in custom_config.items(): |
|
|
ckpt["config"][key] = value |
|
|
|
|
|
if ckpt_only: |
|
|
return ckpt |
|
|
|
|
|
model = PointTransformerV3(**ckpt["config"]) |
|
|
model.load_state_dict(ckpt["state_dict"]) |
|
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print(f"Model params: {n_parameters / 1e6:.2f}M") |
|
|
return model |
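

# Minimal usage sketch (non-authoritative; assumes a Pointcept-style input
# dict whose tensors already live on the model's device):
#
#   model = load("sonata").cuda().eval()
#   data = {
#       "coord": coord,        # (N, 3) float32 xyz positions
#       "feat": feat,          # (N, in_channels) per-point features
#       "offset": offset,      # (B,) cumulative point counts per sample
#       "grid_size": 0.02,     # voxel size used to derive grid coordinates
#   }
#   with torch.no_grad():
#       point = model(data)    # Point object; features in point.feat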
|
|
|