Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from .part_encoders import PartEncoder | |
| from ..autoencoders import VolumeDecoderShapeVAE | |
| from ...utils.misc import ( | |
| instantiate_from_config, | |
| instantiate_non_trainable_model, | |
| ) | |
| from .sonata_extractor import SonataFeatureExtractor | |
| from .part_encoders import PartEncoder | |
def debug_sonata_feat(points, feats):
    """Build a PCA-coloured point cloud for visually inspecting point features.

    Each feature vector is L2-normalised, the set is reduced to three PCA
    components, and the components are rescaled (with one global min/max) to
    ``[0, 255]`` and used as RGB colours with a fully opaque alpha channel.

    Args:
        points: Point coordinates, one row per point. A ``torch.Tensor`` is
            detached and moved to CPU/NumPy first; anything else is handed to
            trimesh unchanged.
        feats: ``torch.Tensor`` of per-point features, one row per point
            (presumably the Sonata features this file extracts; confirm
            against the caller).

    Returns:
        A ``trimesh.points.PointCloud`` with RGBA vertex colours. The caller
        may save it, e.g. ``pc.export(os.path.join("debug", "point_pca.glb"))``.
    """
    # Heavy, debug-only dependencies stay local so importing this module does
    # not require them.
    from sklearn.decomposition import PCA
    import numpy as np
    import trimesh

    if isinstance(points, torch.Tensor):
        # Mirror the handling of ``feats``: trimesh needs host-side data.
        points = points.detach().float().cpu().numpy()
    point_num = points.shape[0]

    feat_np = feats.float().detach().cpu().numpy()
    norms = np.linalg.norm(feat_np, axis=-1, keepdims=True)
    # Clamp the norm so all-zero feature rows stay zero instead of becoming
    # NaN (which PCA rejects).
    unit_feats = feat_np / np.maximum(norms, 1e-12)

    reduced = PCA(n_components=3).fit_transform(unit_feats)

    low = reduced.min()
    span = reduced.max() - low
    if span > 0:
        reduced = (reduced - low) / span
    else:
        # Degenerate case (all components identical): avoid 0/0.
        reduced = np.zeros_like(reduced)

    rgb = (reduced * 255).astype(np.uint8)
    alpha = np.full((point_num, 1), 255, dtype=np.uint8)
    rgba = np.concatenate([rgb, alpha], axis=-1)
    return trimesh.points.PointCloud(points, colors=rgba)
class Conditioner(torch.nn.Module):
    """Assemble the conditioning tensors for a part from surface samples.

    Depending on the flags, ``forward`` produces a dict with:

    * ``"geo_cond"``: output of the trainable geometry encoder run on the
      part surface and the object surface;
    * ``"obj_cond"``: output of the non-trainable object encoder's
      ``encode_shape`` run on the object surface (under ``no_grad``).

    When ``use_seg_feat`` is set, per-point features from the segmentation
    feature extractor are looked up for the points each encoder reported
    (nearest neighbour on the object surface), concatenated onto the
    corresponding condition, and passed through the matching output
    projection when one was created.
    """

    def __init__(
        self,
        use_image=False,
        use_geo=True,
        use_obj=True,
        use_seg_feat=False,
        geo_cfg=None,
        obj_encoder_cfg=None,
        seg_feat_cfg=None,
        **kwargs
    ):
        """Create the enabled encoders and their optional output projections.

        Args:
            use_image: Stored on the instance; not read anywhere in this class.
            use_geo: Enable the geometry (part) encoder branch.
            use_obj: Enable the object encoder branch.
            use_seg_feat: Enable fusion of segmentation point features.
            geo_cfg: Config for the geometry encoder, built with
                ``instantiate_from_config``.
            obj_encoder_cfg: Config for the object encoder, built with
                ``instantiate_non_trainable_model``.
            seg_feat_cfg: Config for the segmentation feature extractor,
                built with ``instantiate_non_trainable_model``.
            **kwargs: Accepted and ignored.

        A projection layer is only created when the corresponding config
        exposes an ``output_dim`` attribute.
        """
        super().__init__()
        self.use_image = use_image
        self.use_obj = use_obj
        self.use_geo = use_geo
        self.use_seg_feat = use_seg_feat
        self.geo_cfg = geo_cfg
        self.obj_encoder_cfg = obj_encoder_cfg
        self.seg_feat_cfg = seg_feat_cfg

        if use_geo and geo_cfg is not None:
            self.geo_encoder: PartEncoder = instantiate_from_config(geo_cfg)
            if hasattr(geo_cfg, "output_dim"):
                # NOTE(review): 1024 + 512 is presumably encoder width plus
                # segmentation feature width -- confirm against the configs.
                self.geo_out_proj = torch.nn.Linear(1024 + 512, geo_cfg.output_dim)

        if use_obj and obj_encoder_cfg is not None:
            self.obj_encoder: VolumeDecoderShapeVAE = instantiate_non_trainable_model(
                obj_encoder_cfg
            )
            if hasattr(obj_encoder_cfg, "output_dim"):
                self.obj_out_proj = torch.nn.Linear(
                    1024 + 512, obj_encoder_cfg.output_dim
                )

        if use_seg_feat and seg_feat_cfg is not None:
            self.seg_feat_encoder: SonataFeatureExtractor = (
                instantiate_non_trainable_model(seg_feat_cfg)
            )
            if hasattr(seg_feat_cfg, "output_dim"):
                # Created for the configured model; ``forward`` below does not
                # call it.
                self.seg_feat_outproj = torch.nn.Linear(512, seg_feat_cfg.output_dim)

    @staticmethod
    def _fuse_point_feats(
        cond, sampled_xyz, surface_xyz, point_feat, num_parts, out_proj
    ):
        """Append nearest-surface-point features to ``cond`` and project.

        For every point in ``sampled_xyz`` the closest point of
        ``surface_xyz`` is found, its row of ``point_feat`` is gathered, and
        the result is concatenated onto ``cond`` along the last dimension.
        ``out_proj`` is applied (after casting to its weight dtype) only when
        it is not ``None``.
        """
        nearest = torch.argmin(torch.cdist(sampled_xyz, surface_xyz), dim=-1)
        gather_index = nearest.unsqueeze(-1).expand(-1, -1, point_feat.size(-1))
        # ``point_feat`` has a leading dimension of one; share it across parts.
        gathered = point_feat.expand(num_parts, -1, -1).gather(1, gather_index)
        fused = torch.concat([cond, gathered], dim=-1)
        if out_proj is None:
            # No projection configured: return the raw concatenation instead
            # of failing on a missing attribute.
            return fused
        return out_proj(fused.to(dtype=out_proj.weight.dtype))

    def forward(self, part_surface_inbbox, object_surface):
        """Compute the conditioning dict for a batch of parts.

        Args:
            part_surface_inbbox: Part surface samples; the first dimension is
                used as the number of parts.
            object_surface: Object surface samples; the last dimension is
                read as xyz in ``[..., :3]`` and normals in ``[..., 3:6]``.

        Returns:
            Dict containing ``"geo_cond"`` and/or ``"obj_cond"`` depending on
            the enabled branches (empty when none is enabled).
        """
        context = {}
        local_pc_infos = None
        global_pc_infos = None

        # geo cond
        if self.use_geo:
            context["geo_cond"], local_pc_infos = self.geo_encoder(
                part_surface_inbbox,
                object_surface,
                return_local_pc_info=True,
            )

        # obj cond (frozen encoder, so no gradients are tracked)
        if self.use_obj:
            with torch.no_grad():
                context["obj_cond"], global_pc_infos = self.obj_encoder.encode_shape(
                    object_surface, return_pc_info=True
                )

        # seg feat cond
        if self.use_seg_feat:
            # TODO: batchsize must be One -- only the first entry of
            # ``object_surface`` is fed to the feature extractor and its
            # features are shared by every part.
            num_parts = part_surface_inbbox.shape[0]
            with torch.autocast(device_type="cuda", dtype=torch.float32):
                with torch.no_grad():
                    point = object_surface[:1, ..., :3].float()
                    normal = object_surface[:1, ..., 3:6].float()
                    point_feat = self.seg_feat_encoder(point, normal)

                surface_xyz = object_surface[..., :3]

                # global feat
                if self.use_obj:
                    context["obj_cond"] = self._fuse_point_feats(
                        context["obj_cond"],
                        global_pc_infos[0],
                        surface_xyz,
                        point_feat,
                        num_parts,
                        getattr(self, "obj_out_proj", None),
                    )

                # local feat
                if self.use_geo:
                    context["geo_cond"] = self._fuse_point_feats(
                        context["geo_cond"],
                        local_pc_infos[0],
                        surface_xyz,
                        point_feat,
                        num_parts,
                        getattr(self, "geo_out_proj", None),
                    )
        return context