import random
import torch
from torch import nn
import numpy as np
import re
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms

from diffusers.models.modeling_utils import ModelMixin
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_device

from .base import BaseLabelEncoder

DEFAULT_POSE = 0  # "unknown", "t-pose", "a-pose", uncond
NUM_POSE_CLASSES = 3
POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3}

DEFAULT_SYMMETRY_TYPE = 0  # "asymmetry", "x", uncond
NUM_SYMMETRY_TYPE_CLASSES = 2
SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2}

DEFAULT_GEOMETRY_QUALITY = 0  # "normal", "smooth", "sharp", uncond
NUM_GEOMETRY_QUALITY_CLASSES = 3
GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncond": 3}
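
# Note: each "uncond" id equals its table's NUM_*_CLASSES value, i.e. it
# addresses the extra embedding row allocated when `zero_uncond_embeds`
# is False (see `configure` below).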


class LabelEncoder(BaseLabelEncoder, ModelMixin):
    """
    Embeds class labels (pose, symmetry type, geometry quality) into vector
    representations, and provides the unconditional embeddings used for
    classifier-free guidance. See the usage sketch at the end of this file.

    Config:
        hidden_size (`int`): The size of each label embedding vector.
        zero_uncond_embeds (`bool`): If `True`, unconditional embeddings are
            all-zero vectors; otherwise each table learns an extra "uncond"
            class.
        pretrained_model_name_or_path (`Optional[str]`): Optional checkpoint
            whose `label_condition.*` weights initialize this encoder.
    """

    def configure(self) -> None:
        super().configure()

        if self.cfg.zero_uncond_embeds:
            # Unconditional embeddings are all-zero vectors, so each table
            # only needs one row per real class.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size
            )
        else:
            # Reserve one extra row per table for the learned "uncond" class.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size
            )

        if self.cfg.zero_uncond_embeds:
            self.empty_label_embeds = torch.zeros(
                (1, 3, self.cfg.hidden_size)
            ).detach()
        else:
            # The last row of each embedding table is the uncond class.
            self.empty_label_embeds = self.encode_label(
                [
                    {
                        "pose": ["uncond"],
                        "symmetry": ["uncond"],
                        "geometry_type": ["uncond"],
                    }
                ]
            ).detach()

        # Optionally initialize from a pretrained checkpoint, keeping only
        # the label-encoder weights and stripping their prefix.
        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("label_condition."):
                    pretrained_model_ckpt[k.replace("label_condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)

    def encode_label(self, labels: List[dict]) -> torch.FloatTensor:
        tpose_label_embeds = []
        symmetry_type_label_embeds = []
        geometry_quality_label_embeds = []

        for label in labels:
            # Pose: an empty value drops the condition (zero embedding);
            # otherwise embed the mapped class id from the pose table.
            if "pose" in label.keys():
                if label["pose"] is None or label["pose"] == "":
                    tpose_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    tpose_label_embeds.append(
                        self.embedding_table_tpose(
                            torch.tensor(POSE_MAPPING[label["pose"][0]]).to(
                                get_device()
                            )
                        )
                    )
            else:
                tpose_label_embeds.append(
                    self.embedding_table_tpose(
                        torch.tensor(DEFAULT_POSE).to(get_device())
                    )
                )
            if "symmetry" in label.keys():
                if label["symmetry"] is None or label["symmetry"] == "":
                    symmetry_type_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    symmetry_type_label_embeds.append(
                        self.embedding_table_symmetry_type(
                            torch.tensor(
                                SYMMETRY_TYPE_MAPPING[label["symmetry"][0]]
                            ).to(get_device())
                        )
                    )
            else:
                symmetry_type_label_embeds.append(
                    self.embedding_table_symmetry_type(
                        torch.tensor(DEFAULT_SYMMETRY_TYPE).to(get_device())
                    )
                )
            if "geometry_type" in label.keys():
                if label["geometry_type"] is None or label["geometry_type"] == "":
                    geometry_quality_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    geometry_quality_label_embeds.append(
                        self.embedding_table_geometry_quality(
                            torch.tensor(
                                GEOMETRY_QUALITY_MAPPING[label["geometry_type"][0]]
                            ).to(get_device())
                        )
                    )
            else:
                geometry_quality_label_embeds.append(
                    self.embedding_table_geometry_quality(
                        torch.tensor(DEFAULT_GEOMETRY_QUALITY).to(get_device())
                    )
                )

        tpose_label_embeds = torch.stack(tpose_label_embeds)
        symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds)
        geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds)

        # Stack the three per-label embeddings into shape (B, 3, hidden_size).
        label_embeds = torch.stack(
            [
                tpose_label_embeds,
                symmetry_type_label_embeds,
                geometry_quality_label_embeds,
            ],
            dim=1,
        ).to(self.dtype)

        return label_embeds
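

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the construction call below is a
# placeholder, since instantiation normally goes through the package's
# config system, which must supply at least `hidden_size` and
# `zero_uncond_embeds`):
#
#     encoder = LabelEncoder(cfg)  # hypothetical construction
#
#     # Each label value is a one-element list; encode_label reads element
#     # [0] and maps it through the *_MAPPING tables. An empty string drops
#     # the condition and yields a zero embedding for that slot.
#     labels = [
#         {"pose": ["t-pose"], "symmetry": ["x"], "geometry_type": ["sharp"]},
#         {"pose": "", "symmetry": "", "geometry_type": ""},
#     ]
#     cond = encoder.encode_label(labels)  # shape: (2, 3, hidden_size)
#
#     # Classifier-free guidance: pair each conditional batch entry with the
#     # unconditional embeddings along the batch dimension.
#     uncond = encoder.empty_label_embeds.to(cond.device).repeat(len(labels), 1, 1)
#     cfg_input = torch.cat([uncond, cond], dim=0)  # shape: (4, 3, hidden_size)
# ---------------------------------------------------------------------------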