Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,576 Bytes
7b75adb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import sys
import torch
import torch.nn as nn
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen'))
from models import sonata
from utils.misc import smart_load_model
'''
This is the P3-SAM model.
The model is composed of three parts:
1. Sonata: a 3D-CNN model for point cloud feature extraction.
2. SEG1+SEG2: a two-stage multi-head segmentor
3. IoU prediction: an IoU predictor
'''
def build_P3SAM(self):
    """Attach all P3-SAM sub-networks to ``self`` as attributes.

    Attributes are created in this order: the Sonata backbone (``sonata``),
    its projection MLP (``mlp``) and default input transform (``transform``);
    the stage-1 heads ``seg_mlp_1..3``; the stage-2 global MLP
    ``seg_s2_mlp_g`` and heads ``seg_s2_mlp_1..3``; the IoU MLPs ``iou_mlp``
    and ``iou_mlp_out``; and an ``MSELoss`` stored as ``iou_criterion``.

    The function returns nothing; it only mutates ``self``.
    """

    def three_layer_mlp(in_dim, width, out_dim):
        # Linear -> GELU -> Linear -> GELU -> Linear (no activation on the output).
        return nn.Sequential(
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, out_dim),
        )

    # NOTE(review): the 3-wide terms are presumably per-point xyz / normal /
    # prompt-point channels concatenated to the feature -- confirm in forward().
    feat_dim = 512
    stage1_in = feat_dim + 3 + 3
    stage2_in = feat_dim + 3 + 3 + 3
    stage2_head_in = stage2_in + 256  # stage-2 input plus the 256-wide global MLP output

    # ---------------------------- Sonata ----------------------------
    self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata')
    # NOTE(review): 1232 is presumably the width of the Sonata output feature -- confirm.
    self.mlp = three_layer_mlp(1232, feat_dim, feat_dim)
    self.transform = sonata.transform.default()

    # ----------------------------- SEG1 -----------------------------
    # Three structurally identical heads, each emitting one value per input row.
    self.seg_mlp_1 = three_layer_mlp(stage1_in, 512, 1)
    self.seg_mlp_2 = three_layer_mlp(stage1_in, 512, 1)
    self.seg_mlp_3 = three_layer_mlp(stage1_in, 512, 1)

    # ----------------------------- SEG2 -----------------------------
    self.seg_s2_mlp_g = three_layer_mlp(stage2_in, 256, 256)
    self.seg_s2_mlp_1 = three_layer_mlp(stage2_head_in, 256, 1)
    self.seg_s2_mlp_2 = three_layer_mlp(stage2_head_in, 256, 1)
    self.seg_s2_mlp_3 = three_layer_mlp(stage2_head_in, 256, 1)

    # ------------------------ IoU prediction ------------------------
    self.iou_mlp = three_layer_mlp(stage2_head_in, 256, 256)
    # Final IoU MLP ends in 3 outputs.
    self.iou_mlp_out = three_layer_mlp(256, 256, 3)
    self.iou_criterion = torch.nn.MSELoss()
'''
Load the P3-SAM weights into the model.
If ckpt_path is not None, the checkpoint at that path is loaded (it takes precedence over state_dict).
Otherwise, if state_dict is not None, the weights are taken from the given state_dict.
If both ckpt_path and state_dict are None, the checkpoint is downloaded from Hugging Face and then loaded.
'''
def load_state_dict(self,
                    ckpt_path=None,
                    state_dict=None,
                    strict=True,
                    assign=False,
                    ignore_seg_mlp=False,
                    ignore_seg_s2_mlp=False,
                    ignore_iou_mlp=False):
    """Copy P3-SAM weights into ``self`` in place and report key problems.

    Weight source, in order of precedence:

    1. ``ckpt_path`` -- a checkpoint file whose ``"state_dict"`` entry is used.
    2. ``state_dict`` -- an already-loaded mapping of name -> tensor.
    3. Neither given -- ``p3sam.ckpt`` is downloaded from the
       ``tencent/Hunyuan3D-Part`` Hugging Face repo and loaded.

    A leading ``"dit."`` on an incoming key is stripped before matching.
    Matching tensors are copied into the tensors returned by
    ``self.state_dict()``. Unexpected keys, shape mismatches and missing keys
    are only printed; nothing is raised for them.

    Args:
        ckpt_path: Path to a checkpoint file, or ``None``.
        state_dict: Mapping of parameter name to tensor, or ``None``.
        strict: Not read by this function (presumably kept to mirror
            ``torch.nn.Module.load_state_dict``).
        assign: Not read by this function (same remark as ``strict``).
        ignore_seg_mlp: Do not list missing keys containing ``'seg_mlp'``
            individually; print one summary line instead.
        ignore_seg_s2_mlp: Same, for keys containing ``'seg_s2_mlp'``.
        ignore_iou_mlp: Same, for keys containing ``'iou_mlp'``.

    Returns:
        None.
    """
    if ckpt_path is None and state_dict is None:
        # Nothing was supplied: fetch the released checkpoint first.
        print('trying to download model from huggingface...')
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/')
        print(f'download model from huggingface to: {ckpt_path}')

    if ckpt_path is not None:
        # NOTE(security): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    local_state_dict = self.state_dict()
    loaded_keys = set()
    for k, v in state_dict.items():
        if k.startswith("dit."):
            k = k[4:]
        if k not in local_state_dict:
            print(f"unexpected key {k} in loaded state dict")
            continue
        loaded_keys.add(k)
        target = local_state_dict[k]
        if target.shape == v.shape:
            target.copy_(v)
        else:
            # `v` is the loaded tensor, `target` is the model's own tensor.
            print(f"mismatching shape for key {k}: loaded {v.shape} but model has {target.shape}")

    # Each flag records that at least one key of an ignored group was missing;
    # a flag can only become True when its ignore_* option is enabled.
    seg_mlp_flag = False
    seg_s2_mlp_flag = False
    iou_mlp_flag = False
    for k in local_state_dict:
        if k in loaded_keys:
            continue
        if ignore_seg_mlp and 'seg_mlp' in k:
            seg_mlp_flag = True
        elif ignore_seg_s2_mlp and 'seg_s2_mlp' in k:
            seg_s2_mlp_flag = True
        elif ignore_iou_mlp and 'iou_mlp' in k:
            iou_mlp_flag = True
        else:
            print(f"missing key {k} in loaded state dict")

    if seg_mlp_flag:
        print("seg_mlp is missing in loaded state dict, ignore seg_mlp in loaded state dict")
    if seg_s2_mlp_flag:
        print("seg_s2_mlp is missing in loaded state dict, ignore seg_s2_mlp in loaded state dict")
    if iou_mlp_flag:
        print("iou_mlp is missing in loaded state dict, ignore iou_mlp in loaded state dict")
|