import os
import sys
import torch 
import torch.nn as nn 
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen'))
from models import sonata
from utils.misc import smart_load_model

'''
This is the P3-SAM model.
The model is composed of three parts:
1. Sonata: a pre-trained point-cloud backbone used for per-point feature extraction.
2. SEG1 + SEG2: a two-stage, multi-head segmentation decoder.
3. IoU head: predicts an IoU score for each candidate mask.
A usage sketch showing how these parts are attached to an nn.Module follows build_P3SAM below.
'''
def build_P3SAM(self):
    ######################## Sonata ########################
    # Pre-trained Sonata point-cloud backbone (weights pulled from the
    # facebook/sonata Hugging Face repo into download_root).
    self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata')
    # Project the 1232-d backbone features down to 512-d per-point features.
    self.mlp = nn.Sequential(
            nn.Linear(1232, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 512),
        )
    # Default Sonata pre-processing transform for input point clouds.
    self.transform = sonata.transform.default()
    ######################## Sonata ########################

    ######################## SEG1 ########################
    # Three parallel SEG1 heads: each maps a (512+3+3)-d per-point input to a
    # single mask logit, giving three candidate masks.
    self.seg_mlp_1 = nn.Sequential(
            nn.Linear(512+3+3, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 1),
        )
    self.seg_mlp_2 = nn.Sequential(
            nn.Linear(512+3+3, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 1),
        )
    self.seg_mlp_3 = nn.Sequential(
            nn.Linear(512+3+3, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 1),
        )
    ######################## SEG1 ########################

    ######################## SEG2 ########################
    # SEG2 refinement stage: a shared MLP maps the (512+3+3+3)-d input to a
    # 256-d feature matching the extra 256 channels consumed by the three
    # refinement heads below, each of which emits a refined per-point mask logit.
    self.seg_s2_mlp_g = nn.Sequential(
            nn.Linear(512+3+3+3, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
        )
    self.seg_s2_mlp_1 = nn.Sequential(
            nn.Linear(512+3+3+3+256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )
    self.seg_s2_mlp_2 = nn.Sequential(
            nn.Linear(512+3+3+3+256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )
    self.seg_s2_mlp_3 = nn.Sequential(
            nn.Linear(512+3+3+3+256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )
    ######################## SEG2 ########################

    ######################## IoU ########################
    # IoU prediction head: outputs three IoU estimates, one per SEG head.
    self.iou_mlp = nn.Sequential(
            nn.Linear(512+3+3+3+256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
        )
    self.iou_mlp_out = nn.Sequential(
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 256),
            nn.GELU(),
            nn.Linear(256, 3),
        )
    # MSE loss used to supervise the predicted IoU scores.
    self.iou_criterion = torch.nn.MSELoss()
    ######################## IoU ########################
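
# ---------------------------------------------------------------------------
# Usage sketch (illustrative assumption, not part of the original file):
# because build_P3SAM takes `self` rather than being a method, it is meant to
# be called from an nn.Module's __init__, which registers the Sonata backbone,
# the SEG1/SEG2 heads, and the IoU predictor as submodules. The class name
# P3SAMExample is hypothetical.
# ---------------------------------------------------------------------------
class P3SAMExample(nn.Module):
    def __init__(self):
        super().__init__()
        build_P3SAM(self)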

'''
Load the P3-SAM weights from a checkpoint.
If ckpt_path is not None, load the checkpoint from the given path.
Otherwise, if state_dict is not None, load the weights directly from that state_dict.
If both ckpt_path and state_dict are None, download p3sam.ckpt from the
tencent/Hunyuan3D-Part repo on Hugging Face and load it.
The ignore_* flags suppress missing-key warnings for the corresponding heads.
'''
def load_state_dict(self, 
                    ckpt_path=None, 
                    state_dict=None, 
                    strict=True, 
                    assign=False, 
                    ignore_seg_mlp=False, 
                    ignore_seg_s2_mlp=False, 
                    ignore_iou_mlp=False):
    if ckpt_path is not None:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    elif state_dict is None:
        # download from huggingface
        print('No checkpoint provided, downloading from Hugging Face...')
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/')
        print(f'Downloaded checkpoint to: {ckpt_path}')
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    local_state_dict = self.state_dict()
    seen_keys = {k: False for k in local_state_dict.keys()}
    for k, v in state_dict.items():
        # Checkpoint keys may be stored under a 'dit.' prefix; strip it.
        if k.startswith("dit."):
            k = k[4:]
        if k in local_state_dict:
            seen_keys[k] = True
            if local_state_dict[k].shape == v.shape:
                # Copy the checkpoint tensor into the model parameter in place.
                local_state_dict[k].copy_(v)
            else:
                print(f"mismatching shape for key {k}: checkpoint has {v.shape} but model expects {local_state_dict[k].shape}")
        else:
            print(f"unexpected key {k} in loaded state dict")
    seg_mlp_flag = False
    seg_s2_mlp_flag = False
    iou_mlp_flag = False
    for k in seen_keys:
        if not seen_keys[k]:
            if ignore_seg_mlp and 'seg_mlp' in k:
                seg_mlp_flag = True
            elif ignore_seg_s2_mlp and 'seg_s2_mlp' in k:
                seg_s2_mlp_flag = True
            elif ignore_iou_mlp and 'iou_mlp' in k:
                iou_mlp_flag = True
            else:
                print(f"missing key {k} in loaded state dict")
    if ignore_seg_mlp and seg_mlp_flag:
        print("seg_mlp keys are missing from the loaded state dict; ignoring them as requested")
    if ignore_seg_s2_mlp and seg_s2_mlp_flag:
        print("seg_s2_mlp keys are missing from the loaded state dict; ignoring them as requested")
    if ignore_iou_mlp and iou_mlp_flag:
        print("iou_mlp keys are missing from the loaded state dict; ignoring them as requested")
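
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code):
# build the hypothetical P3SAMExample wrapper defined above and fetch the
# released checkpoint. With no ckpt_path/state_dict, load_state_dict falls
# back to downloading p3sam.ckpt from tencent/Hunyuan3D-Part on Hugging Face,
# using the download roots hard-coded above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = P3SAMExample()
    load_state_dict(model)  # downloads and loads the released weights
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"P3-SAM initialized with {n_params:,} parameters")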