# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from .part_encoders import PartEncoder
from .sonata_extractor import SonataFeatureExtractor
from ..autoencoders import VolumeDecoderShapeVAE
from ...utils.misc import (
    instantiate_from_config,
    instantiate_non_trainable_model,
)

def debug_sonata_feat(points, feats):
    """Visualize per-point features as a PCA-colored trimesh PointCloud.

    Features are L2-normalized, projected to 3 components with PCA,
    rescaled to [0, 1], and used as RGB (with an opaque alpha channel).
    """
    import numpy as np
    import trimesh
    from sklearn.decomposition import PCA

    point_num = points.shape[0]
    feat_save = feats.float().detach().cpu().numpy()
    data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
    pca = PCA(n_components=3)
    data_reduced = pca.fit_transform(data_scaled)
    data_reduced = (data_reduced - data_reduced.min()) / (
        data_reduced.max() - data_reduced.min()
    )
    # Map the 3 PCA components to RGB and append a fully opaque alpha channel.
    colors_255 = (data_reduced * 255).astype(np.uint8)
    colors_255 = np.concatenate(
        [colors_255, np.ones((point_num, 1), dtype=np.uint8) * 255], axis=-1
    )
    pc_save = trimesh.points.PointCloud(points, colors=colors_255)
    # Callers can export the result, e.g. pc_save.export("debug/point_pca.glb").
    return pc_save
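
# Example (hypothetical tensors, for illustration only): dump the PCA-colored
# cloud produced by debug_sonata_feat to inspect what the features capture.
#
#   points = object_surface[0, :, :3].cpu().numpy()  # (N, 3) positions
#   feats = point_feat[0]                            # (N, C) Sonata features
#   debug_sonata_feat(points, feats).export("debug/point_pca.glb")
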
class Conditioner(torch.nn.Module):
    """Builds the conditioning signals for part generation.

    Combines up to three sources: a trainable part-level geometry encoder,
    a frozen object-level shape VAE, and frozen per-point Sonata features.
    """

    def __init__(
        self,
        use_image=False,
        use_geo=True,
        use_obj=True,
        use_seg_feat=False,
        geo_cfg=None,
        obj_encoder_cfg=None,
        seg_feat_cfg=None,
        **kwargs,
    ):
        super().__init__()
        self.use_image = use_image
        self.use_obj = use_obj
        self.use_geo = use_geo
        self.use_seg_feat = use_seg_feat
        self.geo_cfg = geo_cfg
        self.obj_encoder_cfg = obj_encoder_cfg
        self.seg_feat_cfg = seg_feat_cfg
        # Part-level geometry encoder (trainable).
        if use_geo and geo_cfg is not None:
            self.geo_encoder: PartEncoder = instantiate_from_config(geo_cfg)
            if hasattr(geo_cfg, "output_dim"):
                # Projects encoder tokens (1024) concatenated with Sonata
                # point features (512) down to the requested width.
                self.geo_out_proj = torch.nn.Linear(1024 + 512, geo_cfg.output_dim)

        # Object-level shape VAE (frozen; used only for encoding).
        if use_obj and obj_encoder_cfg is not None:
            self.obj_encoder: VolumeDecoderShapeVAE = instantiate_non_trainable_model(
                obj_encoder_cfg
            )
            if hasattr(obj_encoder_cfg, "output_dim"):
                self.obj_out_proj = torch.nn.Linear(
                    1024 + 512, obj_encoder_cfg.output_dim
                )

        # Per-point semantic feature extractor (frozen).
        if use_seg_feat and seg_feat_cfg is not None:
            self.seg_feat_encoder: SonataFeatureExtractor = (
                instantiate_non_trainable_model(seg_feat_cfg)
            )
            if hasattr(seg_feat_cfg, "output_dim"):
                self.seg_feat_outproj = torch.nn.Linear(512, seg_feat_cfg.output_dim)
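
    # Each cfg above is an attribute-style node (e.g. OmegaConf). Hypothetical
    # shape, for illustration only; the real values live in this repo's
    # config files:
    #
    #   geo_cfg:
    #     target: <module path of the encoder class>
    #     params: {...}
    #     output_dim: 1024   # optional; enables the projection above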
    def forward(self, part_surface_inbbox, object_surface):
        """Assemble the conditioning dict from part and object surface samples.

        Returns a dict with "geo_cond" (part-level tokens) and/or "obj_cond"
        (object-level tokens), depending on which branches are enabled.
        """
        context = {}

        # Part-level geometry condition (also returns the sampled part points).
        if self.use_geo:
context["geo_cond"], local_pc_infos = self.geo_encoder(
part_surface_inbbox,
object_surface,
return_local_pc_info=True,
)
        # Object-level condition from the frozen shape VAE (also returns the
        # sampled object points).
        if self.use_obj:
with torch.no_grad():
context["obj_cond"], global_pc_infos = self.obj_encoder.encode_shape(
object_surface, return_pc_info=True
)
        # Per-point semantic condition from the frozen Sonata extractor.
        if self.use_seg_feat:
            # TODO: only batch size 1 is supported; every part shares the
            # same object surface.
            num_parts = part_surface_inbbox.shape[0]
            # dtype=torch.float32 is not a supported CUDA autocast dtype, so
            # autocast warns and disables itself, keeping this block in full
            # precision.
            with torch.autocast(device_type="cuda", dtype=torch.float32):
                with torch.no_grad():
                    point, normal = (
                        object_surface[:1, ..., :3].float(),
                        object_surface[:1, ..., 3:6].float(),
                    )
                    point_feat = self.seg_feat_encoder(point, normal)
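
            # The Sonata features live on the raw surface points; transfer
            # them to each encoder's sampled points by nearest-neighbor lookup
            # (argmin over pairwise distances), then broadcast across parts.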
            # Global branch: append Sonata features to the object tokens.
            if self.use_obj:
                nearest_global_matches = torch.argmin(
                    torch.cdist(global_pc_infos[0], object_surface[..., :3]), dim=-1
                )
                global_point_feats = point_feat.expand(num_parts, -1, -1).gather(
                    1,
                    nearest_global_matches.unsqueeze(-1).expand(
                        -1, -1, point_feat.size(-1)
                    ),
                )
                context["obj_cond"] = torch.cat(
                    [context["obj_cond"], global_point_feats], dim=-1
                )
                # Cast only when the projection exists, so a missing
                # obj_out_proj no longer raises an AttributeError.
                if hasattr(self, "obj_out_proj"):
                    context["obj_cond"] = self.obj_out_proj(
                        context["obj_cond"].to(self.obj_out_proj.weight.dtype)
                    )
            # Local branch: append Sonata features to the part tokens.
            if self.use_geo:
                nearest_local_matches = torch.argmin(
                    torch.cdist(local_pc_infos[0], object_surface[..., :3]), dim=-1
                )
                local_point_feats = point_feat.expand(num_parts, -1, -1).gather(
                    1,
                    nearest_local_matches.unsqueeze(-1).expand(
                        -1, -1, point_feat.size(-1)
                    ),
                )
                context["geo_cond"] = torch.cat(
                    [context["geo_cond"], local_point_feats], dim=-1
                )
                if hasattr(self, "geo_out_proj"):
                    context["geo_cond"] = self.geo_out_proj(
                        context["geo_cond"].to(self.geo_out_proj.weight.dtype)
                    )
return context
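

# Usage sketch (illustrative shapes; the actual config nodes come from this
# repo's config files and are not reproduced here):
#
#   conditioner = Conditioner(
#       use_geo=True, use_obj=True, use_seg_feat=True,
#       geo_cfg=geo_cfg, obj_encoder_cfg=obj_cfg, seg_feat_cfg=seg_cfg,
#   )
#   # part_surface_inbbox: (P, M, 6) points+normals cropped to each part bbox
#   # object_surface:      (P, N, 6) points+normals of the whole object
#   context = conditioner(part_surface_inbbox, object_surface)
#   # context["geo_cond"] / context["obj_cond"]: (P, tokens, output_dim)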