root
add our app
7b75adb
raw
history blame
5.58 kB
import os
import sys
import torch
import torch.nn as nn
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen'))
from models import sonata
from utils.misc import smart_load_model
'''
This is the P3-SAM model.
The model is composed of three parts:
1. Sonata: a 3D-CNN model for point cloud feature extraction.
2. SEG1+SEG2: a two-stage multi-head segmentor
3. IoU prediction: an IoU predictor
'''
def _p3sam_mlp(in_dim, hidden_dim, out_dim):
    """Build a 3-layer MLP: Linear -> GELU -> Linear -> GELU -> Linear.

    The submodule order inside the ``nn.Sequential`` (Linear layers at indices
    0, 2 and 4) is exactly the layout of the hand-written blocks this helper
    replaces, so parameter names in ``state_dict()`` -- e.g.
    ``seg_mlp_1.0.weight`` or ``iou_mlp_out.4.bias`` -- are unchanged and
    existing checkpoints keep loading.
    """
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, out_dim),
    )


def build_P3SAM(self, sonata_download_root='/root/sonata'):
    """Create every P3-SAM sub-module and attach it to ``self``.

    Presumably called on an ``nn.Module`` instance (e.g. from its
    ``__init__`` -- TODO confirm against the caller), so the ``nn`` layers
    assigned here get registered as submodules. Attributes are created and
    assigned in the same order as the original hand-written version, so
    parameter initialisation order and state-dict key order are unchanged.

    Args:
        self: object that receives the sub-modules as attributes.
        sonata_download_root: directory passed to ``sonata.load`` as
            ``download_root``. Defaults to the previously hard-coded
            ``'/root/sonata'``; override it when that path is not writable.
    """
    # ---- Sonata backbone + feature projection ----
    self.sonata = sonata.load("sonata", repo_id="facebook/sonata",
                              download_root=sonata_download_root)
    # 1232 -> 512 projection; 1232 is presumably the width of the Sonata
    # features fed into this MLP -- confirm against the forward pass.
    self.mlp = _p3sam_mlp(1232, 512, 512)
    self.transform = sonata.transform.default()

    # ---- SEG1: three single-logit heads on (512 + 3 + 3)-d inputs ----
    seg1_in = 512 + 3 + 3
    self.seg_mlp_1 = _p3sam_mlp(seg1_in, 512, 1)
    self.seg_mlp_2 = _p3sam_mlp(seg1_in, 512, 1)
    self.seg_mlp_3 = _p3sam_mlp(seg1_in, 512, 1)

    # ---- SEG2: a 256-d "_g" branch plus three single-logit heads whose
    # input is widened by 256 (presumably the "_g" output concatenated in --
    # confirm against the forward pass) ----
    seg2_in = 512 + 3 + 3 + 3
    self.seg_s2_mlp_g = _p3sam_mlp(seg2_in, 256, 256)
    self.seg_s2_mlp_1 = _p3sam_mlp(seg2_in + 256, 256, 1)
    self.seg_s2_mlp_2 = _p3sam_mlp(seg2_in + 256, 256, 1)
    self.seg_s2_mlp_3 = _p3sam_mlp(seg2_in + 256, 256, 1)

    # ---- IoU prediction: 256-d trunk, 3 outputs, MSE criterion ----
    self.iou_mlp = _p3sam_mlp(seg2_in + 256, 256, 256)
    self.iou_mlp_out = _p3sam_mlp(256, 256, 3)
    self.iou_criterion = torch.nn.MSELoss()
'''
Load P3-SAM weights into the model.
If ckpt_path is given, load the checkpoint from that path (this takes precedence over state_dict).
Otherwise, if state_dict is given, load the weights from that state_dict.
If both ckpt_path and state_dict are None, download p3sam.ckpt from the Hugging Face repo tencent/Hunyuan3D-Part and load it.
'''
def load_state_dict(self,
                    ckpt_path=None,
                    state_dict=None,
                    strict=True,
                    assign=False,
                    ignore_seg_mlp=False,
                    ignore_seg_s2_mlp=False,
                    ignore_iou_mlp=False):
    """Copy checkpoint weights into ``self`` in place, reporting mismatches.

    Source of the weights, in order of precedence:
      1. ``ckpt_path`` -- ``torch.load(ckpt_path, map_location="cpu")["state_dict"]``;
      2. ``state_dict`` -- used as given;
      3. neither -- ``p3sam.ckpt`` is downloaded from the Hugging Face repo
         ``tencent/Hunyuan3D-Part`` into ``/cache/P3-SAM/`` and loaded as in (1).

    A leading ``"dit."`` prefix is stripped from every loaded key. Keys present
    in ``self.state_dict()`` with a matching shape are copied with ``copy_``
    (for an ``nn.Module`` those tensors share storage with its parameters and
    buffers). Shape mismatches, unexpected keys and missing keys are printed,
    never raised.

    Args:
        ckpt_path: path to a checkpoint whose top level has a ``"state_dict"`` entry.
        state_dict: mapping of parameter name -> tensor; ignored if ``ckpt_path`` is given.
        strict: accepted for signature compatibility; currently unused
            (loading is always lenient).
        assign: accepted for signature compatibility; currently unused.
        ignore_seg_mlp, ignore_seg_s2_mlp, ignore_iou_mlp: when True, missing
            keys whose name contains ``'seg_mlp'`` / ``'seg_s2_mlp'`` /
            ``'iou_mlp'`` are not reported one by one; a single summary line
            is printed for the group instead.

    Returns:
        None.
    """
    # NOTE: torch.load unpickles its input -- only load trusted checkpoints.
    if ckpt_path is not None:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    elif state_dict is None:
        # download from huggingface
        print('trying to download model from huggingface...')
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/')
        print(f'download model from huggingface to: {ckpt_path}')
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    local_state_dict = self.state_dict()
    seen = set()
    for key, value in state_dict.items():
        if key.startswith("dit."):
            key = key[len("dit."):]
        if key not in local_state_dict:
            print(f"unexpected key {key} in loaded state dict")
            continue
        seen.add(key)
        target = local_state_dict[key]
        if target.shape == value.shape:
            target.copy_(value)
        else:
            # "loaded" is the checkpoint tensor, "model has" the local one
            # (the previous message had the two shapes swapped).
            print(f"mismatching shape for key {key}: loaded {value.shape} but model has {target.shape}")

    # (group substring, ignore flag). Order matters: the first enabled group
    # whose substring occurs in a key claims it, mirroring an if/elif chain.
    ignorable = [
        ("seg_mlp", ignore_seg_mlp),
        ("seg_s2_mlp", ignore_seg_s2_mlp),
        ("iou_mlp", ignore_iou_mlp),
    ]
    skipped_groups = set()
    for key in local_state_dict:
        if key in seen:
            continue
        group = next((name for name, ignore in ignorable if ignore and name in key), None)
        if group is None:
            print(f"missing key {key} in loaded state dict")
        else:
            skipped_groups.add(group)
    for name, _ in ignorable:
        if name in skipped_groups:
            print(f"{name} is missing in loaded state dict, ignore {name} in loaded state dict")