staswrs committed
Commit f27b461 · 1 Parent(s): 0b60769
.DS_Store ADDED
Binary file (6.15 kB).
 
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Yuchen Lin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,143 @@
1
- ---
2
- license: mit
3
- ---
1
+ # PartCrafter: Structured 3D Mesh Generation via Compositional Latent Diffusion Transformers
2
+
3
+ <h4 align="center">
4
+
5
+ [Yuchen Lin<sup>*</sup>](https://wgsxm.github.io), [Chenguo Lin<sup>*</sup>](https://chenguolin.github.io), [Panwang Pan<sup>†</sup>](https://paulpanwang.github.io), [Honglei Yan](https://openreview.net/profile?id=~Honglei_Yan1), [Yiqiang Feng](https://openreview.net/profile?id=~Feng_Yiqiang1), [Yadong Mu](http://www.muyadong.com), [Katerina Fragkiadaki](https://www.cs.cmu.edu/~katef/)
6
+
7
+ [![arXiv](https://img.shields.io/badge/arXiv-2506.05573-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2506.05573)
8
+ [![Project Page](https://img.shields.io/badge/🏠-Project%20Page-blue.svg)](https://wgsxm.github.io/projects/partcrafter)
9
+ [<img src="https://img.shields.io/badge/YouTube-Video-red" alt="YouTube">](https://www.youtube.com/watch?v=ZaZHbkkPtXY)
10
+ [![Model](https://img.shields.io/badge/🤗%20Model-PartCrafter-yellow.svg)](https://huggingface.co/wgsxm/PartCrafter)
11
+ [![License: MIT](https://img.shields.io/badge/📄%20License-MIT-green)](./LICENSE)
12
+
13
+ <p align="center">
14
+ <img width="90%" alt="pipeline" src="./assets/teaser.png">
15
+ </p>
16
+
17
+ </h4>
18
+
19
+ This repository contains the official implementation of the paper: [PartCrafter: Structured 3D Mesh Generation via Compositional Latent Diffusion Transformers](https://wgsxm.github.io/projects/partcrafter/).
20
+ PartCrafter is a structured 3D generative model that jointly generates multiple parts and objects from a single RGB image in one shot.
21
+ Here is our [Project Page](https://wgsxm.github.io/projects/partcrafter).
22
+
23
+ Feel free to contact me (linyuchen@stu.pku.edu.cn) or open an issue if you have any questions or suggestions.
24
+
25
+
26
+ ## 📢 News
27
+ - **2025-07-20**: A guide for installing PartCrafter on Windows is available in [this fork](https://github.com/JackDainzh/PartCrafter-Windows/tree/windows-main). Thanks to [JackDainzh](https://github.com/JackDainzh)!
28
+ - **2025-07-13**: PartCrafter is fully open-sourced 🚀.
29
+ - **2025-06-09**: PartCrafter is on arXiv.
30
+
31
+ ## 📋 TODO
32
+ - [x] Release inference scripts and pretrained checkpoints.
33
+ - [x] Release training code and data preprocessing scripts.
34
+ - [ ] Provide a HuggingFace🤗 demo.
35
+ - [ ] Release preprocessed dataset.
36
+
37
+ ## 🔧 Installation
38
+ We use `torch-2.5.1+cu124` and `python-3.11`, but other versions should also work. Optionally, create a conda environment and install PyTorch with the following commands:
39
+ ```
40
+ conda create -n partcrafter python=3.11.13
41
+ conda activate partcrafter
42
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
43
+ ```
44
+ Then, install other dependencies with the following command:
45
+ ```
46
+ git clone https://github.com/wgsxm/PartCrafter.git
47
+ cd PartCrafter
48
+ bash settings/setup.sh
49
+ ```
50
+ If you do not have root access and use a conda environment, you can install the required graphics libraries with the following command:
51
+ ```
52
+ conda install -c conda-forge libegl libglu pyopengl
53
+ ```
54
+ We tested the above installation on Debian 12 with NVIDIA H20 GPUs. Windows users can set up the environment by following [this pull request](https://github.com/wgsxm/PartCrafter/pull/24) and [this fork](https://github.com/JackDainzh/PartCrafter-Windows/tree/windows-main). We sincerely thank [JackDainzh](https://github.com/JackDainzh) for contributing Windows support!
55
+
56
+ ## 💡 Quick Start
57
+ <p align="center">
58
+ <img width="90%" alt="pipeline" src="./assets/robot.gif">
59
+ </p>
60
+
61
+ Generate a 3D part-level object from an image:
62
+ ```
63
+ python scripts/inference_partcrafter.py \
64
+ --image_path assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png \
65
+ --num_parts 3 --tag robot --render
66
+ ```
67
+ The required model weights will be automatically downloaded:
68
+ - PartCrafter model from [wgsxm/PartCrafter](https://huggingface.co/wgsxm/PartCrafter) → pretrained_weights/PartCrafter
69
+ - RMBG model from [briaai/RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) → pretrained_weights/RMBG-1.4
70
+
71
+ The generated results will be saved to `./results/robot`. We provide several example images from Objaverse and ABO in `./assets/images`. Their filenames start with the recommended number of parts, e.g., `np3` means 3 parts. You can also try other part counts for the same input images.
72
+
73
+ Specify `--rmbg` if you use custom images. **This removes the background of the input image and resizes it appropriately.**
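+
+ For example, a run on a custom image (the file name below is only a placeholder) could look like:
+ ```
+ python scripts/inference_partcrafter.py \
+ --image_path path/to/your_image.png \
+ --num_parts 4 --tag my_object \
+ --rmbg --render
+ ```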
74
+
75
+ ## 💻 System Requirements
76
+ A CUDA-enabled GPU with at least 8 GB of VRAM is required. You can reduce the number of parts or the number of tokens to save GPU memory. The number of tokens per part defaults to `1024` for better quality.
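+
+ For example, the following command (an illustration using the bundled robot image) reduces the number of tokens per part to lower memory usage:
+ ```
+ python scripts/inference_partcrafter.py \
+ --image_path assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png \
+ --num_parts 3 --num_tokens 512 --tag robot_lowmem
+ ```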
77
+
78
+ ## 📊 Dataset
79
+ Please refer to [Dataset README](./datasets/README.md) to download and preprocess the dataset. To generate a minimal dataset, you can run:
80
+ ```
81
+ python datasets/preprocess/preprocess.py --input assets/objects --output preprocessed_data
82
+ ```
83
+ This script preprocesses the GLB files in `./assets/objects` and saves the preprocessed data to `./preprocessed_data`. We provide an example data configuration [here](./datasets/object_part_configs.json), which uses this minimal preprocessed data and is compatible with the training settings.
84
+
85
+ ## 🦾 Training
86
+ To train PartCrafter from scratch, first download TripoSG from [VAST-AI/TripoSG](https://huggingface.co/VAST-AI/TripoSG) and store the weights in `./pretrained_weights/TripoSG`:
87
+ ```
88
+ huggingface-cli download VAST-AI/TripoSG --local-dir pretrained_weights/TripoSG
89
+ ```
90
+
91
+ Our training scripts are designed for 8 H20 GPUs (96 GB VRAM each). Currently, we only finetune the DiT of TripoSG and keep the VAE fixed, but you can also finetune the VAE, which should further improve the quality of the generated 3D parts. PartCrafter is compatible with any 3D object generative model based on vector sets, such as [Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1). We warmly welcome pull requests from the community.
92
+
93
+ We provide several training configurations [here](./configs). Modify the dataset config path in the training config files if you use your own dataset; it currently points to `./datasets/object_part_configs.json`.
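+
+ The relevant block in each training config (taken from the config files in this repository) looks like this; point `config` at your own dataset JSON if needed:
+ ```yaml
+ dataset:
+   config:
+     - 'datasets/object_part_configs.json'  # point this at your own dataset config
+ ```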
94
+
95
+ If you use `wandb`, set the `WANDB_API_KEY` in the training script. If you have trouble connecting to `wandb`, try `export WANDB_BASE_URL=https://api.bandw.top`.
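+
+ A minimal sketch of the related environment setup (replace the placeholder key with your own):
+ ```
+ export WANDB_API_KEY=your_api_key_here
+ export WANDB_BASE_URL=https://api.bandw.top # only if you have connectivity issues
+ ```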
96
+
97
+ Train PartCrafter from TripoSG:
98
+ ```
99
+ bash scripts/train_partcrafter.sh --config configs/mp8_nt512.yaml --use_ema \
100
+ --gradient_accumulation_steps 4 \
101
+ --output_dir output_partcrafter \
102
+ --tag scaleup_mp8_nt512
103
+ ```
104
+
105
+ Finetune PartCrafter with a larger number of parts:
106
+ ```
107
+ bash scripts/train_partcrafter.sh --config configs/mp16_nt512.yaml --use_ema \
108
+ --gradient_accumulation_steps 4 \
109
+ --output_dir output_partcrafter \
110
+ --load_pretrained_model scaleup_mp8_nt512 \
111
+ --load_pretrained_model_ckpt 10 \
112
+ --tag scaleup_mp16_nt512
113
+ ```
114
+
115
+ Finetune PartCrafter with more tokens:
116
+ ```
117
+ bash scripts/train_partcrafter.sh --config configs/mp16_nt1024.yaml --use_ema \
118
+ --gradient_accumulation_steps 4 \
119
+ --output_dir output_partcrafter \
120
+ --load_pretrained_model scaleup_mp16_nt512 \
121
+ --load_pretrained_model_ckpt 10 \
122
+ --tag scaleup_mp16_nt1024
123
+ ```
124
+
125
+ ## 😊 Acknowledgement
126
+ We would like to thank the authors of [DiffSplat](https://chenguolin.github.io/projects/DiffSplat/), [TripoSG](https://yg256li.github.io/TripoSG-Page/), [HoloPart](https://vast-ai-research.github.io/HoloPart/), and [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
127
+ for their great work and for generously providing their source code, which inspired our work and helped us a lot in the implementation.
128
+
129
+
130
+ ## 📚 Citation
131
+ If you find our work helpful, please consider citing:
132
+ ```bibtex
133
+ @misc{lin2025partcrafter,
134
+ title={PartCrafter: Structured 3D Mesh Generation via Compositional Latent Diffusion Transformers},
135
+ author={Yuchen Lin and Chenguo Lin and Panwang Pan and Honglei Yan and Yiqiang Feng and Yadong Mu and Katerina Fragkiadaki},
136
+ year={2025},
137
+ eprint={2506.05573},
138
+ url={https://arxiv.org/abs/2506.05573}
139
+ }
140
+ ```
141
+
142
+ ## 🌟 Star History
143
+ [![Star History Chart](https://api.star-history.com/svg?repos=wgsxm/PartCrafter&type=Date)](https://www.star-history.com/#wgsxm/PartCrafter&Date)
configs/mp16_nt1024.yaml ADDED
@@ -0,0 +1,77 @@
1
+ model:
2
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
3
+ vae:
4
+ num_tokens: 1024
5
+ transformer:
6
+ enable_local_cross_attn: true
7
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
8
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
9
+
10
+
11
+ dataset:
12
+ config:
13
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
14
+ training_ratio: 0.9
15
+ min_num_parts: 1
16
+ max_num_parts: 16
17
+ max_iou_mean: 0.2
18
+ max_iou_max: 0.2
19
+ shuffle_parts: true
20
+ object_ratio: 0.3
21
+ rotating_ratio: 0.2
22
+ rotating_degree: 10
23
+
24
+ optimizer:
25
+ name: "adamw"
26
+ lr: 5e-5
27
+ betas:
28
+ - 0.9
29
+ - 0.999
30
+ weight_decay: 0.01
31
+ eps: 1.e-8
32
+
33
+ lr_scheduler:
34
+ name: "constant_warmup"
35
+ num_warmup_steps: 1000
36
+
37
+ train:
38
+ batch_size_per_gpu: 32
39
+ epochs: 10
40
+ grad_checkpoint: true
41
+ weighting_scheme: "logit_normal"
42
+ logit_mean: 0.0
43
+ logit_std: 1.0
44
+ mode_scale: 1.29
45
+ cfg_dropout_prob: 0.1
46
+ training_objective: "-v"
47
+ log_freq: 1
48
+ early_eval_freq: 500
49
+ early_eval: 1000
50
+ eval_freq: 1000
51
+ save_freq: 2000
52
+ eval_freq_epoch: 5
53
+ save_freq_epoch: 10
54
+ ema_kwargs:
55
+ decay: 0.9999
56
+ use_ema_warmup: true
57
+ inv_gamma: 1.
58
+ power: 0.75
59
+
60
+ val:
61
+ batch_size_per_gpu: 1
62
+ nrow: 4
63
+ min_num_parts: 2
64
+ max_num_parts: 8
65
+ num_inference_steps: 50
66
+ max_num_expanded_coords: 1e8
67
+ use_flash_decoder: false
68
+ rendering:
69
+ radius: 4.0
70
+ num_views: 36
71
+ fps: 18
72
+ metric:
73
+ cd_num_samples: 204800
74
+ cd_metric: "l2"
75
+ f1_score_threshold: 0.1
76
+ default_cd: 1e6
77
+ default_f1: 0.0
configs/mp16_nt512.yaml ADDED
@@ -0,0 +1,77 @@
1
+ model:
2
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
3
+ vae:
4
+ num_tokens: 512
5
+ transformer:
6
+ enable_local_cross_attn: true
7
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
8
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
9
+
10
+
11
+ dataset:
12
+ config:
13
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
14
+ training_ratio: 0.9
15
+ min_num_parts: 1
16
+ max_num_parts: 16
17
+ max_iou_mean: 0.5
18
+ max_iou_max: 0.5
19
+ shuffle_parts: true
20
+ object_ratio: 0.3
21
+ rotating_ratio: 0.2
22
+ rotating_degree: 10
23
+
24
+ optimizer:
25
+ name: "adamw"
26
+ lr: 5e-5
27
+ betas:
28
+ - 0.9
29
+ - 0.999
30
+ weight_decay: 0.01
31
+ eps: 1.e-8
32
+
33
+ lr_scheduler:
34
+ name: "constant_warmup"
35
+ num_warmup_steps: 1000
36
+
37
+ train:
38
+ batch_size_per_gpu: 32
39
+ epochs: 10
40
+ grad_checkpoint: true
41
+ weighting_scheme: "logit_normal"
42
+ logit_mean: 0.0
43
+ logit_std: 1.0
44
+ mode_scale: 1.29
45
+ cfg_dropout_prob: 0.1
46
+ training_objective: "-v"
47
+ log_freq: 1
48
+ early_eval_freq: 500
49
+ early_eval: 1000
50
+ eval_freq: 1000
51
+ save_freq: 2000
52
+ eval_freq_epoch: 5
53
+ save_freq_epoch: 10
54
+ ema_kwargs:
55
+ decay: 0.9999
56
+ use_ema_warmup: true
57
+ inv_gamma: 1.
58
+ power: 0.75
59
+
60
+ val:
61
+ batch_size_per_gpu: 1
62
+ nrow: 4
63
+ min_num_parts: 2
64
+ max_num_parts: 8
65
+ num_inference_steps: 50
66
+ max_num_expanded_coords: 1e8
67
+ use_flash_decoder: false
68
+ rendering:
69
+ radius: 4.0
70
+ num_views: 36
71
+ fps: 18
72
+ metric:
73
+ cd_num_samples: 204800
74
+ cd_metric: "l2"
75
+ f1_score_threshold: 0.1
76
+ default_cd: 1e6
77
+ default_f1: 0.0
configs/mp8_nt512.yaml ADDED
@@ -0,0 +1,77 @@
1
+ model:
2
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
3
+ vae:
4
+ num_tokens: 512
5
+ transformer:
6
+ enable_local_cross_attn: true
7
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
8
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
9
+
10
+
11
+ dataset:
12
+ config:
13
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
14
+ training_ratio: 0.9
15
+ min_num_parts: 1
16
+ max_num_parts: 8
17
+ max_iou_mean: 0.5
18
+ max_iou_max: 0.5
19
+ shuffle_parts: true
20
+ object_ratio: 0.3
21
+ rotating_ratio: 0.2
22
+ rotating_degree: 10
23
+
24
+ optimizer:
25
+ name: "adamw"
26
+ lr: 1e-4
27
+ betas:
28
+ - 0.9
29
+ - 0.999
30
+ weight_decay: 0.01
31
+ eps: 1.e-8
32
+
33
+ lr_scheduler:
34
+ name: "constant_warmup"
35
+ num_warmup_steps: 1000
36
+
37
+ train:
38
+ batch_size_per_gpu: 32
39
+ epochs: 10
40
+ grad_checkpoint: true
41
+ weighting_scheme: "logit_normal"
42
+ logit_mean: 0.0
43
+ logit_std: 1.0
44
+ mode_scale: 1.29
45
+ cfg_dropout_prob: 0.1
46
+ training_objective: "-v"
47
+ log_freq: 1
48
+ early_eval_freq: 500
49
+ early_eval: 1000
50
+ eval_freq: 1000
51
+ save_freq: 2000
52
+ eval_freq_epoch: 5
53
+ save_freq_epoch: 10
54
+ ema_kwargs:
55
+ decay: 0.9999
56
+ use_ema_warmup: true
57
+ inv_gamma: 1.
58
+ power: 0.75
59
+
60
+ val:
61
+ batch_size_per_gpu: 1
62
+ nrow: 4
63
+ min_num_parts: 2
64
+ max_num_parts: 8
65
+ num_inference_steps: 50
66
+ max_num_expanded_coords: 1e8
67
+ use_flash_decoder: false
68
+ rendering:
69
+ radius: 4.0
70
+ num_views: 36
71
+ fps: 18
72
+ metric:
73
+ cd_num_samples: 204800
74
+ cd_metric: "l2"
75
+ f1_score_threshold: 0.1
76
+ default_cd: 1e6
77
+ default_f1: 0.0
handler.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch==2.1.0
2
+ pytorch-lightning
3
+ huggingface-hub
4
+ diffusers
5
+ transformers
6
+ omegaconf
7
+ trimesh
8
+ tqdm
9
+ pillow
scripts/inference_partcrafter.py ADDED
@@ -0,0 +1,176 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from glob import glob
5
+ import time
6
+ from typing import Any, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import trimesh
11
+ from huggingface_hub import snapshot_download
12
+ from PIL import Image
13
+ from accelerate.utils import set_seed
14
+
15
+ from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
16
+ from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings
17
+ from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
18
+ from src.utils.image_utils import prepare_image
19
+ from src.models.briarmbg import BriaRMBG
20
+
21
+ @torch.no_grad()
22
+ def run_triposg(
23
+ pipe: Any,
24
+ image_input: Union[str, Image.Image],
25
+ num_parts: int,
26
+ rmbg_net: Any,
27
+ seed: int,
28
+ num_tokens: int = 1024,
29
+ num_inference_steps: int = 50,
30
+ guidance_scale: float = 7.0,
31
+ max_num_expanded_coords: int = 1e9,
32
+ use_flash_decoder: bool = False,
33
+ rmbg: bool = False,
34
+ dtype: torch.dtype = torch.float16,
35
+ device: str = "cuda",
36
+ ) -> trimesh.Scene:
37
+
38
+ if rmbg:
39
+ img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
40
+ else:
41
+ img_pil = Image.open(image_input)
42
+ start_time = time.time()
43
+ outputs = pipe(
44
+ image=[img_pil] * num_parts,
45
+ attention_kwargs={"num_parts": num_parts},
46
+ num_tokens=num_tokens,
47
+ generator=torch.Generator(device=pipe.device).manual_seed(seed),
48
+ num_inference_steps=num_inference_steps,
49
+ guidance_scale=guidance_scale,
50
+ max_num_expanded_coords=max_num_expanded_coords,
51
+ use_flash_decoder=use_flash_decoder,
52
+ ).meshes
53
+ end_time = time.time()
54
+ print(f"Time elapsed: {end_time - start_time:.2f} seconds")
55
+ for i in range(len(outputs)):
56
+ if outputs[i] is None:
57
+ # If the generated mesh is None (decoding error), use a dummy mesh
58
+ outputs[i] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
59
+ return outputs, img_pil
60
+
61
+ MAX_NUM_PARTS = 16
62
+
63
+ if __name__ == "__main__":
64
+ device = "cuda"
65
+ dtype = torch.float16
66
+
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument("--image_path", type=str, required=True)
69
+ parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
70
+ parser.add_argument("--output_dir", type=str, default="./results")
71
+ parser.add_argument("--tag", type=str, default=None)
72
+ parser.add_argument("--seed", type=int, default=0)
73
+ parser.add_argument("--num_tokens", type=int, default=1024)
74
+ parser.add_argument("--num_inference_steps", type=int, default=50)
75
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
76
+ parser.add_argument("--max_num_expanded_coords", type=int, default=1e9)
77
+ parser.add_argument("--use_flash_decoder", action="store_true")
78
+ parser.add_argument("--rmbg", action="store_true")
79
+ parser.add_argument("--render", action="store_true")
80
+ args = parser.parse_args()
81
+
82
+ assert 1 <= args.num_parts <= MAX_NUM_PARTS, f"num_parts must be in [1, {MAX_NUM_PARTS}]"
83
+
84
+ # download pretrained weights
85
+ partcrafter_weights_dir = "pretrained_weights/PartCrafter"
86
+ rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
87
+ snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
88
+ snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
89
+
90
+ # init rmbg model for background removal
91
+ rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
92
+ rmbg_net.eval()
93
+
94
+ # init tripoSG pipeline
95
+ pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)
96
+
97
+ set_seed(args.seed)
98
+
99
+ # run inference
100
+ outputs, processed_image = run_triposg(
101
+ pipe,
102
+ image_input=args.image_path,
103
+ num_parts=args.num_parts,
104
+ rmbg_net=rmbg_net,
105
+ seed=args.seed,
106
+ num_tokens=args.num_tokens,
107
+ num_inference_steps=args.num_inference_steps,
108
+ guidance_scale=args.guidance_scale,
109
+ max_num_expanded_coords=args.max_num_expanded_coords,
110
+ use_flash_decoder=args.use_flash_decoder,
111
+ rmbg=args.rmbg,
112
+ dtype=dtype,
113
+ device=device,
114
+ )
115
+
116
+ if not os.path.exists(args.output_dir):
117
+ os.makedirs(args.output_dir)
118
+
119
+ if args.tag is None:
120
+ args.tag = time.strftime("%Y%m%d_%H_%M_%S")
121
+
122
+ export_dir = os.path.join(args.output_dir, args.tag)
123
+ os.makedirs(export_dir, exist_ok=True)
124
+
125
+ for i, mesh in enumerate(outputs):
126
+ mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
127
+
128
+ merged_mesh = get_colored_mesh_composition(outputs)
129
+ merged_mesh.export(os.path.join(export_dir, "object.glb"))
130
+ print(f"Generated {len(outputs)} parts and saved to {export_dir}")
131
+
132
+ if args.render:
133
+ print("Start rendering...")
134
+ num_views = 36
135
+ radius = 4
136
+ fps = 18
137
+ rendered_images = render_views_around_mesh(
138
+ merged_mesh,
139
+ num_views=num_views,
140
+ radius=radius,
141
+ )
142
+ rendered_normals = render_normal_views_around_mesh(
143
+ merged_mesh,
144
+ num_views=num_views,
145
+ radius=radius,
146
+ )
147
+ rendered_grids = make_grid_for_images_or_videos(
148
+ [
149
+ [processed_image] * num_views,
150
+ rendered_images,
151
+ rendered_normals,
152
+ ],
153
+ nrow=3
154
+ )
155
+ export_renderings(
156
+ rendered_images,
157
+ os.path.join(export_dir, "rendering.gif"),
158
+ fps=fps,
159
+ )
160
+ export_renderings(
161
+ rendered_normals,
162
+ os.path.join(export_dir, "rendering_normal.gif"),
163
+ fps=fps,
164
+ )
165
+ export_renderings(
166
+ rendered_grids,
167
+ os.path.join(export_dir, "rendering_grid.gif"),
168
+ fps=fps,
169
+ )
170
+
171
+ rendered_image, rendered_normal, rendered_grid = rendered_images[0], rendered_normals[0], rendered_grids[0]
172
+ rendered_image.save(os.path.join(export_dir, "rendering.png"))
173
+ rendered_normal.save(os.path.join(export_dir, "rendering_normal.png"))
174
+ rendered_grid.save(os.path.join(export_dir, "rendering_grid.png"))
175
+ print("Rendering done.")
176
+
scripts/train_partcrafter.sh ADDED
@@ -0,0 +1,14 @@
1
+ NUM_MACHINES=1
2
+ NUM_LOCAL_GPUS=8
3
+ MACHINE_RANK=0
4
+
5
+ export WANDB_API_KEY="" # Modify this if you use wandb
6
+
7
+ accelerate launch \
8
+ --num_machines $NUM_MACHINES \
9
+ --num_processes $(( $NUM_MACHINES * $NUM_LOCAL_GPUS )) \
10
+ --machine_rank $MACHINE_RANK \
11
+ src/train_partcrafter.py \
12
+ --pin_memory \
13
+ --allow_tf32 \
14
+ "$@"
settings/requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ scikit-learn
2
+ gpustat
3
+ nvitop
4
+ diffusers
5
+ transformers
6
+ einops
7
+ huggingface_hub
8
+ opencv-python
9
+ trimesh
10
+ omegaconf
11
+ scikit-image
12
+ numpy==1.26.4
13
+ peft
14
+ jaxtyping
15
+ typeguard
16
+ matplotlib
17
+ imageio-ffmpeg
18
+ pyrender
19
+ deepspeed
20
+ wandb[media]
21
+ colormaps
settings/setup.sh ADDED
@@ -0,0 +1,3 @@
1
+ pip install torch-cluster -f https://data.pyg.org/whl/torch-2.5.1+cu124.html
2
+ pip install -r settings/requirements.txt
3
+ sudo apt-get install libegl1 libegl1-mesa libgl1-mesa-dev -y # for rendering
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/datasets/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import torch
4
+
5
+ from .objaverse_part import ObjaversePartDataset, BatchedObjaversePartDataset
6
+
7
+ # Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/loader.py
8
+ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
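+ # Reuses DataLoader worker processes across epochs: the sampler is wrapped in an
+ # infinite _RepeatSampler below, so the underlying iterator is created only once.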
9
+
10
+ def __init__(self, *args, **kwargs):
11
+ super().__init__(*args, **kwargs)
12
+ self._DataLoader__initialized = False
13
+ if self.batch_sampler is None:
14
+ self.sampler = _RepeatSampler(self.sampler)
15
+ else:
16
+ self.batch_sampler = _RepeatSampler(self.batch_sampler)
17
+ self._DataLoader__initialized = True
18
+ self.iterator = super().__iter__()
19
+
20
+ def __len__(self):
21
+ return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)
22
+
23
+ def __iter__(self):
24
+ for i in range(len(self)):
25
+ yield next(self.iterator)
26
+
27
+
28
+ class _RepeatSampler(object):
29
+ """ Sampler that repeats forever.
30
+
31
+ Args:
32
+ sampler (Sampler)
33
+ """
34
+
35
+ def __init__(self, sampler):
36
+ self.sampler = sampler
37
+ if isinstance(self.sampler, torch.utils.data.sampler.BatchSampler):
38
+ self.batch_size = self.sampler.batch_size
39
+ self.drop_last = self.sampler.drop_last
40
+
41
+ def __len__(self):
42
+ return len(self.sampler)
43
+
44
+ def __iter__(self):
45
+ while True:
46
+ yield from iter(self.sampler)
47
+
48
+ def yield_forever(iterator: Iterator[Any]):
49
+ while True:
50
+ for x in iterator:
51
+ yield x
src/datasets/objaverse_part.py ADDED
@@ -0,0 +1,197 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import json
4
+ import os
5
+ import random
6
+
7
+ import accelerate
8
+ import torch
9
+ from torchvision import transforms
10
+ import numpy as np
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+
14
+ from src.utils.data_utils import load_surface, load_surfaces
15
+
16
+ class ObjaversePartDataset(torch.utils.data.Dataset):
17
+ def __init__(
18
+ self,
19
+ configs: DictConfig,
20
+ training: bool = True,
21
+ ):
22
+ super().__init__()
23
+ self.configs = configs
24
+ self.training = training
25
+
26
+ self.min_num_parts = configs['dataset']['min_num_parts']
27
+ self.max_num_parts = configs['dataset']['max_num_parts']
28
+ self.val_min_num_parts = configs['val']['min_num_parts']
29
+ self.val_max_num_parts = configs['val']['max_num_parts']
30
+
31
+ self.max_iou_mean = configs['dataset'].get('max_iou_mean', None)
32
+ self.max_iou_max = configs['dataset'].get('max_iou_max', None)
33
+
34
+ self.shuffle_parts = configs['dataset']['shuffle_parts']
35
+ self.training_ratio = configs['dataset']['training_ratio']
36
+ self.balance_object_and_parts = configs['dataset'].get('balance_object_and_parts', False)
37
+
38
+ self.rotating_ratio = configs['dataset'].get('rotating_ratio', 0.0)
39
+ self.rotating_degree = configs['dataset'].get('rotating_degree', 10.0)
40
+ self.transform = transforms.Compose([
41
+ transforms.RandomRotation(degrees=(-self.rotating_degree, self.rotating_degree), fill=(255, 255, 255)),
42
+ ])
43
+
44
+ if isinstance(configs['dataset']['config'], ListConfig):
45
+ data_configs = []
46
+ for config in configs['dataset']['config']:
47
+ local_data_configs = json.load(open(config))
48
+ if self.balance_object_and_parts:
49
+ if self.training:
50
+ local_data_configs = local_data_configs[:int(len(local_data_configs) * self.training_ratio)]
51
+ else:
52
+ local_data_configs = local_data_configs[int(len(local_data_configs) * self.training_ratio):]
53
+ local_data_configs = [config for config in local_data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
54
+ data_configs += local_data_configs
55
+ else:
56
+ data_configs = json.load(open(configs['dataset']['config']))
57
+ data_configs = [config for config in data_configs if config['valid']]
58
+ data_configs = [config for config in data_configs if self.min_num_parts <= config['num_parts'] <= self.max_num_parts]
59
+ if self.max_iou_mean is not None and self.max_iou_max is not None:
60
+ data_configs = [config for config in data_configs if config['iou_mean'] <= self.max_iou_mean]
61
+ data_configs = [config for config in data_configs if config['iou_max'] <= self.max_iou_max]
62
+ if not self.balance_object_and_parts:
63
+ if self.training:
64
+ data_configs = data_configs[:int(len(data_configs) * self.training_ratio)]
65
+ else:
66
+ data_configs = data_configs[int(len(data_configs) * self.training_ratio):]
67
+ data_configs = [config for config in data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
68
+ self.data_configs = data_configs
69
+ self.image_size = (512, 512)
70
+
71
+ def __len__(self) -> int:
72
+ return len(self.data_configs)
73
+
74
+ def _get_data_by_config(self, data_config):
75
+ if 'surface_path' in data_config:
76
+ surface_path = data_config['surface_path']
77
+ surface_data = np.load(surface_path, allow_pickle=True).item()
78
+ # If parts is empty, the object is the only part
79
+ part_surfaces = surface_data['parts'] if len(surface_data['parts']) > 0 else [surface_data['object']]
80
+ if self.shuffle_parts:
81
+ random.shuffle(part_surfaces)
82
+ part_surfaces = load_surfaces(part_surfaces) # [N, P, 6]
83
+ else:
84
+ part_surfaces = []
85
+ for surface_path in data_config['surface_paths']:
86
+ surface_data = np.load(surface_path, allow_pickle=True).item()
87
+ part_surfaces.append(load_surface(surface_data))
88
+ part_surfaces = torch.stack(part_surfaces, dim=0) # [N, P, 6]
89
+ image_path = data_config['image_path']
90
+ image = Image.open(image_path).resize(self.image_size)
91
+ if random.random() < self.rotating_ratio:
92
+ image = self.transform(image)
93
+ image = np.array(image)
94
+ image = torch.from_numpy(image).to(torch.uint8) # [H, W, 3]
95
+ images = torch.stack([image] * part_surfaces.shape[0], dim=0) # [N, H, W, 3]
96
+ return {
97
+ "images": images,
98
+ "part_surfaces": part_surfaces,
99
+ }
100
+
101
+ def __getitem__(self, idx: int):
102
+ # This dataset only supports batch_size == 1 training,
103
+ # because the number of parts is not fixed per object.
104
+ # See BatchedObjaversePartDataset for batched training.
105
+ data_config = self.data_configs[idx]
106
+ data = self._get_data_by_config(data_config)
107
+ return data
108
+
109
+ class BatchedObjaversePartDataset(ObjaversePartDataset):
110
+ def __init__(
111
+ self,
112
+ configs: DictConfig,
113
+ batch_size: int,
114
+ is_main_process: bool = False,
115
+ shuffle: bool = True,
116
+ training: bool = True,
117
+ ):
118
+ assert training
119
+ assert batch_size > 1
120
+ super().__init__(configs, training)
121
+ self.batch_size = batch_size
122
+ self.is_main_process = is_main_process
123
+ if batch_size < self.max_num_parts:
124
+ self.data_configs = [config for config in self.data_configs if config['num_parts'] <= batch_size]
125
+
126
+ if shuffle:
127
+ random.shuffle(self.data_configs)
128
+
129
+ self.object_configs = [config for config in self.data_configs if config['num_parts'] == 1]
130
+ self.parts_configs = [config for config in self.data_configs if config['num_parts'] > 1]
131
+
132
+ self.object_ratio = configs['dataset']['object_ratio']
133
+ # Here we keep the ratio of object to parts
134
+ self.object_configs = self.object_configs[:int(len(self.parts_configs) * self.object_ratio)]
135
+
136
+ dropped_data_configs = self.parts_configs + self.object_configs
137
+ if shuffle:
138
+ random.shuffle(dropped_data_configs)
139
+
140
+ self.data_configs = self._get_batched_configs(dropped_data_configs, batch_size)
141
+
142
+ def _get_batched_configs(self, data_configs, batch_size):
143
+ batched_data_configs = []
144
+ num_data_configs = len(data_configs)
145
+ progress_bar = tqdm(
146
+ range(len(data_configs)),
147
+ desc="Batching Dataset",
148
+ ncols=125,
149
+ disable=not self.is_main_process,
150
+ )
151
+ while len(data_configs) > 0:
152
+ temp_batch = []
153
+ temp_num_parts = 0
154
+ unchosen_configs = []
155
+ while temp_num_parts < batch_size and len(data_configs) > 0:
156
+ config = data_configs.pop() # pop the last config
157
+ num_parts = config['num_parts']
158
+ if temp_num_parts + num_parts <= batch_size:
159
+ temp_batch.append(config)
160
+ temp_num_parts += num_parts
161
+ progress_bar.update(1)
162
+ else:
163
+ unchosen_configs.append(config) # add back to the end
164
+ data_configs = data_configs + unchosen_configs # concat the unchosen configs
165
+ if temp_num_parts == batch_size:
166
+ # Successfully get a batch
167
+ if len(temp_batch) < batch_size:
168
+ # pad the batch
169
+ temp_batch += [{}] * (batch_size - len(temp_batch))
170
+ batched_data_configs += temp_batch
171
+ # Otherwise, we reach this point because len(data_configs) == 0,
172
+ # i.e., the remaining configs cannot be combined into a batch
173
+ # with exactly batch_size parts in total.
174
+ # In that case, the incomplete batch is dropped.
175
+ progress_bar.close()
176
+ return batched_data_configs
177
+
178
+ def __getitem__(self, idx: int):
179
+ data_config = self.data_configs[idx]
180
+ if len(data_config) == 0:
181
+ # placeholder
182
+ return {}
183
+ data = self._get_data_by_config(data_config)
184
+ return data
185
+
186
+ def collate_fn(self, batch):
187
+ batch = [data for data in batch if len(data) > 0]
188
+ images = torch.cat([data['images'] for data in batch], dim=0) # [N, H, W, 3]
189
+ surfaces = torch.cat([data['part_surfaces'] for data in batch], dim=0) # [N, P, 6]
190
+ num_parts = torch.LongTensor([data['part_surfaces'].shape[0] for data in batch])
191
+ assert images.shape[0] == surfaces.shape[0] == num_parts.sum() == self.batch_size
192
+ batch = {
193
+ "images": images,
194
+ "part_surfaces": surfaces,
195
+ "num_parts": num_parts,
196
+ }
197
+ return batch
src/models/.DS_Store ADDED
Binary file (6.15 kB).
 
src/models/attention_processor.py ADDED
@@ -0,0 +1,624 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import logging
7
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
8
+ from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
9
+ from einops import rearrange
10
+ from torch import nn
11
+
12
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
13
+
14
+ class FlashTripo2AttnProcessor2_0:
15
+ r"""
16
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
17
+ used in the Tripo2DiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
18
+ """
19
+
20
+ def __init__(self, topk=True):
21
+ if not hasattr(F, "scaled_dot_product_attention"):
22
+ raise ImportError(
23
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
24
+ )
25
+ self.topk = topk
26
+
27
+ def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
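+ # Approximate attention: score a strided subset of queries against all keys,
+ # keep only the top-k most relevant key/value entries, then run standard
+ # scaled-dot-product attention on that reduced set to save memory.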
28
+ if k.shape[-2] == 3072:
29
+ topk = 1024
30
+ elif k.shape[-2] == 512:
31
+ topk = 256
32
+ else:
33
+ topk = k.shape[-2] // 3
34
+
35
+ if self.topk is True:
36
+ q1 = q[:, :, ::100, :]
37
+ sim = q1 @ k.transpose(-1, -2)
38
+ sim = torch.mean(sim, -2)
39
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
40
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
41
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
42
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
43
+ out = F.scaled_dot_product_attention(q, k0, v0)
44
+ elif self.topk is False:
45
+ out = F.scaled_dot_product_attention(q, k, v)
46
+ else:
47
+ idx, counts = self.topk
48
+ start = 0
49
+ outs = []
50
+ for grid_coord, count in zip(idx, counts):
51
+ end = start + count
52
+ q_chunk = q[:, :, start:end, :]
53
+ q1 = q_chunk[:, :, ::50, :]
54
+ sim = q1 @ k.transpose(-1, -2)
55
+ sim = torch.mean(sim, -2)
56
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
57
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
58
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
59
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
60
+ out = F.scaled_dot_product_attention(q_chunk, k0, v0)
61
+ outs.append(out)
62
+ start += count
63
+ out = torch.cat(outs, dim=-2)
64
+ self.topk = False
65
+ return out
66
+
67
+ def __call__(
68
+ self,
69
+ attn: Attention,
70
+ hidden_states: torch.Tensor,
71
+ encoder_hidden_states: Optional[torch.Tensor] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ temb: Optional[torch.Tensor] = None,
74
+ image_rotary_emb: Optional[torch.Tensor] = None,
75
+ ) -> torch.Tensor:
76
+ from diffusers.models.embeddings import apply_rotary_emb
77
+
78
+ residual = hidden_states
79
+ if attn.spatial_norm is not None:
80
+ hidden_states = attn.spatial_norm(hidden_states, temb)
81
+
82
+ input_ndim = hidden_states.ndim
83
+
84
+ if input_ndim == 4:
85
+ batch_size, channel, height, width = hidden_states.shape
86
+ hidden_states = hidden_states.view(
87
+ batch_size, channel, height * width
88
+ ).transpose(1, 2)
89
+
90
+ batch_size, sequence_length, _ = (
91
+ hidden_states.shape
92
+ if encoder_hidden_states is None
93
+ else encoder_hidden_states.shape
94
+ )
95
+
96
+ if attention_mask is not None:
97
+ attention_mask = attn.prepare_attention_mask(
98
+ attention_mask, sequence_length, batch_size
99
+ )
100
+ # scaled_dot_product_attention expects attention_mask shape to be
101
+ # (batch, heads, source_length, target_length)
102
+ attention_mask = attention_mask.view(
103
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
104
+ )
105
+
106
+ if attn.group_norm is not None:
107
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
108
+ 1, 2
109
+ )
110
+
111
+ query = attn.to_q(hidden_states)
112
+
113
+ if encoder_hidden_states is None:
114
+ encoder_hidden_states = hidden_states
115
+ elif attn.norm_cross:
116
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
117
+ encoder_hidden_states
118
+ )
119
+
120
+ key = attn.to_k(encoder_hidden_states)
121
+ value = attn.to_v(encoder_hidden_states)
122
+
123
+ # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
124
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
125
+ if not attn.is_cross_attention:
126
+ qkv = torch.cat((query, key, value), dim=-1)
127
+ split_size = qkv.shape[-1] // attn.heads // 3
128
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
129
+ query, key, value = torch.split(qkv, split_size, dim=-1)
130
+ else:
131
+ kv = torch.cat((key, value), dim=-1)
132
+ split_size = kv.shape[-1] // attn.heads // 2
133
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
134
+ key, value = torch.split(kv, split_size, dim=-1)
135
+
136
+ head_dim = key.shape[-1]
137
+
138
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
139
+
140
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
141
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
142
+
143
+ if attn.norm_q is not None:
144
+ query = attn.norm_q(query)
145
+ if attn.norm_k is not None:
146
+ key = attn.norm_k(key)
147
+
148
+ # Apply RoPE if needed
149
+ if image_rotary_emb is not None:
150
+ query = apply_rotary_emb(query, image_rotary_emb)
151
+ if not attn.is_cross_attention:
152
+ key = apply_rotary_emb(key, image_rotary_emb)
153
+
154
+ # flashvdm topk
155
+ hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
156
+
157
+ hidden_states = hidden_states.transpose(1, 2).reshape(
158
+ batch_size, -1, attn.heads * head_dim
159
+ )
160
+ hidden_states = hidden_states.to(query.dtype)
161
+
162
+ # linear proj
163
+ hidden_states = attn.to_out[0](hidden_states)
164
+ # dropout
165
+ hidden_states = attn.to_out[1](hidden_states)
166
+
167
+ if input_ndim == 4:
168
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
169
+ batch_size, channel, height, width
170
+ )
171
+
172
+ if attn.residual_connection:
173
+ hidden_states = hidden_states + residual
174
+
175
+ hidden_states = hidden_states / attn.rescale_output_factor
176
+
177
+ return hidden_states
178
+
179
+ class TripoSGAttnProcessor2_0:
180
+ r"""
181
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
182
+ used in the TripoSG model. It applies a normalization layer and rotary embedding on the query and key vectors.
183
+ """
184
+
185
+ def __init__(self):
186
+ if not hasattr(F, "scaled_dot_product_attention"):
187
+ raise ImportError(
188
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
189
+ )
190
+
191
+ def __call__(
192
+ self,
193
+ attn: Attention,
194
+ hidden_states: torch.Tensor,
195
+ encoder_hidden_states: Optional[torch.Tensor] = None,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ temb: Optional[torch.Tensor] = None,
198
+ image_rotary_emb: Optional[torch.Tensor] = None,
199
+ ) -> torch.Tensor:
200
+ from diffusers.models.embeddings import apply_rotary_emb
201
+
202
+ residual = hidden_states
203
+ if attn.spatial_norm is not None:
204
+ hidden_states = attn.spatial_norm(hidden_states, temb)
205
+
206
+ input_ndim = hidden_states.ndim
207
+
208
+ if input_ndim == 4:
209
+ batch_size, channel, height, width = hidden_states.shape
210
+ hidden_states = hidden_states.view(
211
+ batch_size, channel, height * width
212
+ ).transpose(1, 2)
213
+
214
+ batch_size, sequence_length, _ = (
215
+ hidden_states.shape
216
+ if encoder_hidden_states is None
217
+ else encoder_hidden_states.shape
218
+ )
219
+
220
+ if attention_mask is not None:
221
+ attention_mask = attn.prepare_attention_mask(
222
+ attention_mask, sequence_length, batch_size
223
+ )
224
+ # scaled_dot_product_attention expects attention_mask shape to be
225
+ # (batch, heads, source_length, target_length)
226
+ attention_mask = attention_mask.view(
227
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
228
+ )
229
+
230
+ if attn.group_norm is not None:
231
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
232
+ 1, 2
233
+ )
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
241
+ encoder_hidden_states
242
+ )
243
+
244
+ key = attn.to_k(encoder_hidden_states)
245
+ value = attn.to_v(encoder_hidden_states)
246
+
247
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
248
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
249
+ if not attn.is_cross_attention:
250
+ qkv = torch.cat((query, key, value), dim=-1)
251
+ split_size = qkv.shape[-1] // attn.heads // 3
252
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
253
+ query, key, value = torch.split(qkv, split_size, dim=-1)
254
+ else:
255
+ kv = torch.cat((key, value), dim=-1)
256
+ split_size = kv.shape[-1] // attn.heads // 2
257
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
258
+ key, value = torch.split(kv, split_size, dim=-1)
259
+
260
+ head_dim = key.shape[-1]
261
+
262
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
+
264
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
265
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
266
+
267
+ if attn.norm_q is not None:
268
+ query = attn.norm_q(query)
269
+ if attn.norm_k is not None:
270
+ key = attn.norm_k(key)
271
+
272
+ # Apply RoPE if needed
273
+ if image_rotary_emb is not None:
274
+ query = apply_rotary_emb(query, image_rotary_emb)
275
+ if not attn.is_cross_attention:
276
+ key = apply_rotary_emb(key, image_rotary_emb)
277
+
278
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
279
+ # TODO: add support for attn.scale when we move to Torch 2.1
280
+ hidden_states = F.scaled_dot_product_attention(
281
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
282
+ )
283
+
284
+ hidden_states = hidden_states.transpose(1, 2).reshape(
285
+ batch_size, -1, attn.heads * head_dim
286
+ )
287
+ hidden_states = hidden_states.to(query.dtype)
288
+
289
+ # linear proj
290
+ hidden_states = attn.to_out[0](hidden_states)
291
+ # dropout
292
+ hidden_states = attn.to_out[1](hidden_states)
293
+
294
+ if input_ndim == 4:
295
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
296
+ batch_size, channel, height, width
297
+ )
298
+
299
+ if attn.residual_connection:
300
+ hidden_states = hidden_states + residual
301
+
302
+ hidden_states = hidden_states / attn.rescale_output_factor
303
+
304
+ return hidden_states
305
+
306
+
307
+ class FusedTripoSGAttnProcessor2_0:
308
+ r"""
309
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
310
+ projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on
311
+ query and key vector.
312
+ """
313
+
314
+ def __init__(self):
315
+ if not hasattr(F, "scaled_dot_product_attention"):
316
+ raise ImportError(
317
+ "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
318
+ )
319
+
320
+ def __call__(
321
+ self,
322
+ attn: Attention,
323
+ hidden_states: torch.Tensor,
324
+ encoder_hidden_states: Optional[torch.Tensor] = None,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ temb: Optional[torch.Tensor] = None,
327
+ image_rotary_emb: Optional[torch.Tensor] = None,
328
+ ) -> torch.Tensor:
329
+ from diffusers.models.embeddings import apply_rotary_emb
330
+
331
+ residual = hidden_states
332
+ if attn.spatial_norm is not None:
333
+ hidden_states = attn.spatial_norm(hidden_states, temb)
334
+
335
+ input_ndim = hidden_states.ndim
336
+
337
+ if input_ndim == 4:
338
+ batch_size, channel, height, width = hidden_states.shape
339
+ hidden_states = hidden_states.view(
340
+ batch_size, channel, height * width
341
+ ).transpose(1, 2)
342
+
343
+ batch_size, sequence_length, _ = (
344
+ hidden_states.shape
345
+ if encoder_hidden_states is None
346
+ else encoder_hidden_states.shape
347
+ )
348
+
349
+ if attention_mask is not None:
350
+ attention_mask = attn.prepare_attention_mask(
351
+ attention_mask, sequence_length, batch_size
352
+ )
353
+ # scaled_dot_product_attention expects attention_mask shape to be
354
+ # (batch, heads, source_length, target_length)
355
+ attention_mask = attention_mask.view(
356
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
357
+ )
358
+
359
+ if attn.group_norm is not None:
360
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
361
+ 1, 2
362
+ )
363
+
364
+ # NOTE that pre-trained split heads first, then split qkv
365
+ if encoder_hidden_states is None:
366
+ qkv = attn.to_qkv(hidden_states)
367
+ split_size = qkv.shape[-1] // attn.heads // 3
368
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
369
+ query, key, value = torch.split(qkv, split_size, dim=-1)
370
+ else:
371
+ if attn.norm_cross:
372
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
373
+ encoder_hidden_states
374
+ )
375
+ query = attn.to_q(hidden_states)
376
+
377
+ kv = attn.to_kv(encoder_hidden_states)
378
+ split_size = kv.shape[-1] // attn.heads // 2
379
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
380
+ key, value = torch.split(kv, split_size, dim=-1)
381
+
382
+ head_dim = key.shape[-1]
383
+
384
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
385
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
386
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
387
+
388
+ if attn.norm_q is not None:
389
+ query = attn.norm_q(query)
390
+ if attn.norm_k is not None:
391
+ key = attn.norm_k(key)
392
+
393
+ # Apply RoPE if needed
394
+ if image_rotary_emb is not None:
395
+ query = apply_rotary_emb(query, image_rotary_emb)
396
+ if not attn.is_cross_attention:
397
+ key = apply_rotary_emb(key, image_rotary_emb)
398
+
399
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
400
+ # TODO: add support for attn.scale when we move to Torch 2.1
401
+ hidden_states = F.scaled_dot_product_attention(
402
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
403
+ )
404
+
405
+ hidden_states = hidden_states.transpose(1, 2).reshape(
406
+ batch_size, -1, attn.heads * head_dim
407
+ )
408
+ hidden_states = hidden_states.to(query.dtype)
409
+
410
+ # linear proj
411
+ hidden_states = attn.to_out[0](hidden_states)
412
+ # dropout
413
+ hidden_states = attn.to_out[1](hidden_states)
414
+
415
+ if input_ndim == 4:
416
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
417
+ batch_size, channel, height, width
418
+ )
419
+
420
+ if attn.residual_connection:
421
+ hidden_states = hidden_states + residual
422
+
423
+ hidden_states = hidden_states / attn.rescale_output_factor
424
+
425
+ return hidden_states
426
+
427
+ # Modified from https://github.com/VAST-AI-Research/MIDI-3D/blob/main/midi/models/attention_processor.py#L264
428
+ class PartCrafterAttnProcessor:
429
+ r"""
430
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
431
+ used in the PartCrafter model. It applies a normalization layer and rotary embedding on the query and key vectors.
432
+ """
433
+
434
+ def __init__(self):
435
+ if not hasattr(F, "scaled_dot_product_attention"):
436
+ raise ImportError(
437
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
438
+ )
439
+
440
+
441
+ def __call__(
442
+ self,
443
+ attn: Attention,
444
+ hidden_states: torch.Tensor,
445
+ encoder_hidden_states: Optional[torch.Tensor] = None,
446
+ attention_mask: Optional[torch.Tensor] = None,
447
+ temb: Optional[torch.Tensor] = None,
448
+ image_rotary_emb: Optional[torch.Tensor] = None,
449
+ num_parts: Optional[Union[int, torch.Tensor]] = None,
450
+ ) -> torch.Tensor:
451
+ from diffusers.models.embeddings import apply_rotary_emb
452
+
453
+ residual = hidden_states
454
+ if attn.spatial_norm is not None:
455
+ hidden_states = attn.spatial_norm(hidden_states, temb)
456
+
457
+ input_ndim = hidden_states.ndim
458
+
459
+ if input_ndim == 4:
460
+ batch_size, channel, height, width = hidden_states.shape
461
+ hidden_states = hidden_states.view(
462
+ batch_size, channel, height * width
463
+ ).transpose(1, 2)
464
+
465
+ batch_size, sequence_length, _ = (
466
+ hidden_states.shape
467
+ if encoder_hidden_states is None
468
+ else encoder_hidden_states.shape
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ attention_mask = attn.prepare_attention_mask(
473
+ attention_mask, sequence_length, batch_size
474
+ )
475
+ # scaled_dot_product_attention expects attention_mask shape to be
476
+ # (batch, heads, source_length, target_length)
477
+ attention_mask = attention_mask.view(
478
+                 batch_size, attn.heads, -1, attention_mask.shape[-1]
+             )
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                 1, 2
+             )
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(
+                 encoder_hidden_states
+             )
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
+         # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
+         if not attn.is_cross_attention:
+             qkv = torch.cat((query, key, value), dim=-1)
+             split_size = qkv.shape[-1] // attn.heads // 3
+             qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+             query, key, value = torch.split(qkv, split_size, dim=-1)
+         else:
+             kv = torch.cat((key, value), dim=-1)
+             split_size = kv.shape[-1] // attn.heads // 2
+             kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+             key, value = torch.split(kv, split_size, dim=-1)
+
+         head_dim = key.shape[-1]
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # Apply RoPE if needed
+         if image_rotary_emb is not None:
+             query = apply_rotary_emb(query, image_rotary_emb)
+             if not attn.is_cross_attention:
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+         if isinstance(num_parts, torch.Tensor):
+             # Assume list in training, do not consider classifier-free guidance
+             idx = 0
+             hidden_states_list = []
+             for n_p in num_parts:
+                 k = key[idx : idx + n_p]
+                 v = value[idx : idx + n_p]
+                 q = query[idx : idx + n_p]
+                 idx += n_p
+                 if k.shape[2] == q.shape[2]:
+                     # Assuming self-attention
+                     # Here 'b' is always 1
+                     k = rearrange(
+                         k, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+                     )  # [b, h, ni*nt, c]
+                     v = rearrange(
+                         v, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+                     )  # [b, h, ni*nt, c]
+                 else:
+                     # Assuming cross-attention
+                     # Here 'b' is always 1
+                     k = k[::n_p]  # [b, h, nt, c]
+                     v = v[::n_p]  # [b, h, nt, c]
+                 # Here 'b' is always 1
+                 q = rearrange(
+                     q, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+                 )  # [b, h, ni*nt, c]
+                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                 h_s = F.scaled_dot_product_attention(
+                     q, k, v,
+                     dropout_p=0.0,
+                     is_causal=False,
+                 )
+                 h_s = h_s.transpose(1, 2).reshape(
+                     n_p, -1, attn.heads * head_dim
+                 )
+                 h_s = h_s.to(query.dtype)
+                 hidden_states_list.append(h_s)
+             hidden_states = torch.cat(hidden_states_list, dim=0)
+
+         elif isinstance(num_parts, int):
+             # Assume single instance
+             if key.shape[2] == query.shape[2]:
+                 # Assuming self-attention
+                 # Here we need 'b' when using classifier-free guidance
+                 key = rearrange(
+                     key, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+                 )  # [b, h, ni*nt, c]
+                 value = rearrange(
+                     value, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+                 )  # [b, h, ni*nt, c]
+             else:
+                 # Assuming cross-attention
+                 # Here we need 'b' when using classifier-free guidance
+                 # Control signal is repeated ni times within each (b, ni)
+                 # We select only the first instance per group
+                 key = key[::num_parts]  # [b, h, nt, c]
+                 value = value[::num_parts]  # [b, h, nt, c]
+             query = rearrange(
+                 query, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+             )  # [b, h, ni*nt, c]
+
+             # the output of sdp = (batch, num_heads, seq_len, head_dim)
+             hidden_states = F.scaled_dot_product_attention(
+                 query,
+                 key,
+                 value,
+                 dropout_p=0.0,
+                 is_causal=False,
+             )
+             hidden_states = hidden_states.transpose(1, 2).reshape(
+                 batch_size, -1, attn.heads * head_dim
+             )
+             hidden_states = hidden_states.to(query.dtype)
+
+         else:
+             raise ValueError(
+                 "num_parts must be a torch.Tensor or int, but got {}".format(type(num_parts))
+             )
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(
+                 batch_size, channel, height, width
+             )
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
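The processor above is the core of PartCrafter's part-aware attention: the ni parts of an object ride along the batch dimension, self-attention temporarily concatenates their token sequences so parts can exchange information, and cross-attention keeps only every num_parts-th key/value row because the image condition is repeated once per part. A minimal standalone sketch of that regrouping, with toy shapes and illustrative names that are not taken from the repository:

# Toy illustration of the part-token grouping used by the processor above.
import torch
import torch.nn.functional as F
from einops import rearrange

b, ni, h, nt, c = 2, 4, 8, 16, 64          # batch, parts, heads, tokens per part, head dim
q = torch.randn(b * ni, h, nt, c)          # per-part queries, stacked along the batch axis
k = torch.randn(b * ni, h, nt, c)
v = torch.randn(b * ni, h, nt, c)

# Self-attention across parts: merge the ni part sequences of each object into
# one long sequence so every part token can attend to every other part token.
q_g = rearrange(q, "(b ni) h nt c -> b h (ni nt) c", ni=ni)
k_g = rearrange(k, "(b ni) h nt c -> b h (ni nt) c", ni=ni)
v_g = rearrange(v, "(b ni) h nt c -> b h (ni nt) c", ni=ni)
out = F.scaled_dot_product_attention(q_g, k_g, v_g)   # [b, h, ni*nt, c]

# Undo the grouping to recover one sequence per part.
out = rearrange(out, "b h (ni nt) c -> (b ni) h nt c", ni=ni)
print(out.shape)                           # torch.Size([16, 8, 16, 64])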
src/models/autoencoders/__init__.py ADDED
@@ -0,0 +1 @@
+ from .autoencoder_kl_triposg import TripoSGVAEModel
src/models/autoencoders/autoencoder_kl_triposg.py ADDED
@@ -0,0 +1,536 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.autoencoders.vae import DecoderOutput
9
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.models.normalization import FP32LayerNorm, LayerNorm
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.accelerate_utils import apply_forward_hook
14
+ from einops import repeat
15
+ from torch_cluster import fps
16
+ from tqdm import tqdm
17
+
18
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripo2AttnProcessor2_0
19
+ from ..embeddings import FrequencyPositionalEmbedding
20
+ from ..transformers.partcrafter_transformer import DiTBlock
21
+ from .vae import DiagonalGaussianDistribution
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ class TripoSGEncoder(nn.Module):
27
+ def __init__(
28
+ self,
29
+ in_channels: int = 3,
30
+ dim: int = 512,
31
+ num_attention_heads: int = 8,
32
+ num_layers: int = 8,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
37
+
38
+ self.blocks = nn.ModuleList(
39
+ [
40
+ DiTBlock(
41
+ dim=dim,
42
+ num_attention_heads=num_attention_heads,
43
+ use_self_attention=False,
44
+ use_cross_attention=True,
45
+ cross_attention_dim=dim,
46
+ cross_attention_norm_type="layer_norm",
47
+ activation_fn="gelu",
48
+ norm_type="fp32_layer_norm",
49
+ norm_eps=1e-5,
50
+ qk_norm=False,
51
+ qkv_bias=False,
52
+ ) # cross attention
53
+ ]
54
+ + [
55
+ DiTBlock(
56
+ dim=dim,
57
+ num_attention_heads=num_attention_heads,
58
+ use_self_attention=True,
59
+ self_attention_norm_type="fp32_layer_norm",
60
+ use_cross_attention=False,
61
+ activation_fn="gelu",
62
+ norm_type="fp32_layer_norm",
63
+ norm_eps=1e-5,
64
+ qk_norm=False,
65
+ qkv_bias=False,
66
+ )
67
+ for _ in range(num_layers) # self attention
68
+ ]
69
+ )
70
+
71
+ self.norm_out = LayerNorm(dim)
72
+
73
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
74
+ hidden_states = self.proj_in(sample_1)
75
+ encoder_hidden_states = self.proj_in(sample_2)
76
+
77
+ for layer, block in enumerate(self.blocks):
78
+ if layer == 0:
79
+ hidden_states = block(
80
+ hidden_states, encoder_hidden_states=encoder_hidden_states
81
+ )
82
+ else:
83
+ hidden_states = block(hidden_states)
84
+
85
+ hidden_states = self.norm_out(hidden_states)
86
+
87
+ return hidden_states
88
+
89
+
90
+ class TripoSGDecoder(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels: int = 3,
94
+ out_channels: int = 1,
95
+ dim: int = 512,
96
+ num_attention_heads: int = 8,
97
+ num_layers: int = 16,
98
+ grad_type: str = "analytical",
99
+ grad_interval: float = 0.001,
100
+ ):
101
+ super().__init__()
102
+
103
+ if grad_type not in ["numerical", "analytical"]:
104
+ raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
105
+ self.grad_type = grad_type
106
+ self.grad_interval = grad_interval
107
+
108
+ self.blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim=dim,
112
+ num_attention_heads=num_attention_heads,
113
+ use_self_attention=True,
114
+ self_attention_norm_type="fp32_layer_norm",
115
+ use_cross_attention=False,
116
+ activation_fn="gelu",
117
+ norm_type="fp32_layer_norm",
118
+ norm_eps=1e-5,
119
+ qk_norm=False,
120
+ qkv_bias=False,
121
+ )
122
+ for _ in range(num_layers) # self attention
123
+ ]
124
+ + [
125
+ DiTBlock(
126
+ dim=dim,
127
+ num_attention_heads=num_attention_heads,
128
+ use_self_attention=False,
129
+ use_cross_attention=True,
130
+ cross_attention_dim=dim,
131
+ cross_attention_norm_type="layer_norm",
132
+ activation_fn="gelu",
133
+ norm_type="fp32_layer_norm",
134
+ norm_eps=1e-5,
135
+ qk_norm=False,
136
+ qkv_bias=False,
137
+ ) # cross attention
138
+ ]
139
+ )
140
+
141
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
142
+
143
+ self.norm_out = LayerNorm(dim)
144
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
145
+
146
+ def set_topk(self, topk):
147
+ self.blocks[-1].set_topk(topk)
148
+
149
+ def set_flash_processor(self, processor):
150
+ self.blocks[-1].set_flash_processor(processor)
151
+
152
+ def query_geometry(
153
+ self,
154
+ model_fn: callable,
155
+ queries: torch.Tensor,
156
+ sample: torch.Tensor,
157
+ grad: bool = False,
158
+ ):
159
+ logits = model_fn(queries, sample)
160
+ if grad:
161
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
162
+ if self.grad_type == "numerical":
163
+ interval = self.grad_interval
164
+ grad_value = []
165
+ for offset in [
166
+ (interval, 0, 0),
167
+ (0, interval, 0),
168
+ (0, 0, interval),
169
+ ]:
170
+ offset_tensor = torch.tensor(offset, device=queries.device)[
171
+ None, :
172
+ ]
173
+ res_p = model_fn(queries + offset_tensor, sample)[..., 0]
174
+ res_n = model_fn(queries - offset_tensor, sample)[..., 0]
175
+ grad_value.append((res_p - res_n) / (2 * interval))
176
+ grad_value = torch.stack(grad_value, dim=-1)
177
+ else:
178
+ queries_d = torch.clone(queries)
179
+ queries_d.requires_grad = True
180
+ with torch.enable_grad():
181
+ res_d = model_fn(queries_d, sample)
182
+ grad_value = torch.autograd.grad(
183
+ res_d,
184
+ [queries_d],
185
+ grad_outputs=torch.ones_like(res_d),
186
+ create_graph=self.training,
187
+ )[0]
188
+ else:
189
+ grad_value = None
190
+
191
+ return logits, grad_value
192
+
193
+ def forward(
194
+ self,
195
+ sample: torch.Tensor,
196
+ queries: torch.Tensor,
197
+ kv_cache: Optional[torch.Tensor] = None,
198
+ ):
199
+ if kv_cache is None:
200
+ hidden_states = sample
201
+ for _, block in enumerate(self.blocks[:-1]):
202
+ hidden_states = block(hidden_states)
203
+ kv_cache = hidden_states
204
+
205
+ # query grid logits by cross attention
206
+ def query_fn(q, kv):
207
+ q = self.proj_query(q)
208
+ l = self.blocks[-1](q, encoder_hidden_states=kv)
209
+ return self.proj_out(self.norm_out(l))
210
+
211
+ logits, grad = self.query_geometry(
212
+ query_fn, queries, kv_cache, grad=self.training
213
+ )
214
+ logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
215
+
216
+ return logits, kv_cache
217
+
218
+
219
+ class TripoSGVAEModel(ModelMixin, ConfigMixin):
220
+ @register_to_config
221
+ def __init__(
222
+ self,
223
+ in_channels: int = 3, # NOTE xyz instead of feature dim
224
+ latent_channels: int = 64,
225
+ num_attention_heads: int = 8,
226
+ width_encoder: int = 512,
227
+ width_decoder: int = 1024,
228
+ num_layers_encoder: int = 8,
229
+ num_layers_decoder: int = 16,
230
+ embedding_type: str = "frequency",
231
+ embed_frequency: int = 8,
232
+ embed_include_pi: bool = False,
233
+ ):
234
+ super().__init__()
235
+
236
+ self.out_channels = 1
237
+
238
+ if embedding_type == "frequency":
239
+ self.embedder = FrequencyPositionalEmbedding(
240
+ num_freqs=embed_frequency,
241
+ logspace=True,
242
+ input_dim=in_channels,
243
+ include_pi=embed_include_pi,
244
+ )
245
+ else:
246
+ raise NotImplementedError(
247
+ f"Embedding type {embedding_type} is not supported."
248
+ )
249
+
250
+ self.encoder = TripoSGEncoder(
251
+ in_channels=in_channels + self.embedder.out_dim,
252
+ dim=width_encoder,
253
+ num_attention_heads=num_attention_heads,
254
+ num_layers=num_layers_encoder,
255
+ )
256
+ self.decoder = TripoSGDecoder(
257
+ in_channels=self.embedder.out_dim,
258
+ out_channels=self.out_channels,
259
+ dim=width_decoder,
260
+ num_attention_heads=num_attention_heads,
261
+ num_layers=num_layers_decoder,
262
+ )
263
+
264
+ self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
265
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
266
+
267
+ self.use_slicing = False
268
+ self.slicing_length = 1
269
+
270
+ def set_flash_decoder(self):
271
+ self.decoder.set_flash_processor(FlashTripo2AttnProcessor2_0())
272
+
273
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
274
+ def fuse_qkv_projections(self):
275
+ """
276
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
277
+ are fused. For cross-attention modules, key and value projection matrices are fused.
278
+
279
+ <Tip warning={true}>
280
+
281
+ This API is 🧪 experimental.
282
+
283
+ </Tip>
284
+ """
285
+ self.original_attn_processors = None
286
+
287
+ for _, attn_processor in self.attn_processors.items():
288
+ if "Added" in str(attn_processor.__class__.__name__):
289
+ raise ValueError(
290
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
291
+ )
292
+
293
+ self.original_attn_processors = self.attn_processors
294
+
295
+ for module in self.modules():
296
+ if isinstance(module, Attention):
297
+ module.fuse_projections(fuse=True)
298
+
299
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
300
+
301
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
302
+ def unfuse_qkv_projections(self):
303
+ """Disables the fused QKV projection if enabled.
304
+
305
+ <Tip warning={true}>
306
+
307
+ This API is 🧪 experimental.
308
+
309
+ </Tip>
310
+
311
+ """
312
+ if self.original_attn_processors is not None:
313
+ self.set_attn_processor(self.original_attn_processors)
314
+
315
+ @property
316
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
317
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
318
+ r"""
319
+ Returns:
320
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
321
+ indexed by its weight name.
322
+ """
323
+ # set recursively
324
+ processors = {}
325
+
326
+ def fn_recursive_add_processors(
327
+ name: str,
328
+ module: torch.nn.Module,
329
+ processors: Dict[str, AttentionProcessor],
330
+ ):
331
+ if hasattr(module, "get_processor"):
332
+ processors[f"{name}.processor"] = module.get_processor()
333
+
334
+ for sub_name, child in module.named_children():
335
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
336
+
337
+ return processors
338
+
339
+ for name, module in self.named_children():
340
+ fn_recursive_add_processors(name, module, processors)
341
+
342
+ return processors
343
+
344
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
345
+ def set_attn_processor(
346
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
347
+ ):
348
+ r"""
349
+ Sets the attention processor to use to compute attention.
350
+
351
+ Parameters:
352
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
353
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
354
+ for **all** `Attention` layers.
355
+
356
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
357
+ processor. This is strongly recommended when setting trainable attention processors.
358
+
359
+ """
360
+ count = len(self.attn_processors.keys())
361
+
362
+ if isinstance(processor, dict) and len(processor) != count:
363
+ raise ValueError(
364
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
365
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
366
+ )
367
+
368
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
369
+ if hasattr(module, "set_processor"):
370
+ if not isinstance(processor, dict):
371
+ module.set_processor(processor)
372
+ else:
373
+ module.set_processor(processor.pop(f"{name}.processor"))
374
+
375
+ for sub_name, child in module.named_children():
376
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
377
+
378
+ for name, module in self.named_children():
379
+ fn_recursive_attn_processor(name, module, processor)
380
+
381
+ def set_default_attn_processor(self):
382
+ """
383
+ Disables custom attention processors and sets the default attention implementation.
384
+ """
385
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
386
+
387
+ def enable_slicing(self, slicing_length: int = 1) -> None:
388
+ r"""
389
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
390
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
391
+ """
392
+ self.use_slicing = True
393
+ self.slicing_length = slicing_length
394
+
395
+ def disable_slicing(self) -> None:
396
+ r"""
397
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
398
+ decoding in one step.
399
+ """
400
+ self.use_slicing = False
401
+
402
+ def _sample_features(
403
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
404
+ ):
405
+ """
406
+ Sample points from features of the input point cloud.
407
+
408
+ Args:
409
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
410
+ num_tokens (int, optional): The number of points to sample. Defaults to 2048.
411
+ seed (Optional[int], optional): The random seed. Defaults to None.
412
+ """
413
+ rng = np.random.default_rng(seed)
414
+ indices = rng.choice(
415
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
416
+ )
417
+ selected_points = x[:, indices]
418
+
419
+ batch_size, num_points, num_channels = selected_points.shape
420
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
421
+ batch_indices = (
422
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
423
+ )
424
+
425
+ # fps sampling
426
+ sampling_ratio = 1.0 / 4
427
+ sampled_indices = fps(
428
+ flattened_points[:, :3],
429
+ batch_indices,
430
+ ratio=sampling_ratio,
431
+ random_start=self.training,
432
+ )
433
+ sampled_points = flattened_points[sampled_indices].view(
434
+ batch_size, -1, num_channels
435
+ )
436
+
437
+ return sampled_points
438
+
439
+ def _encode(
440
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
441
+ ):
442
+ position_channels = self.config.in_channels
443
+ positions, features = x[..., :position_channels], x[..., position_channels:]
444
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
445
+
446
+ sampled_x = self._sample_features(x, num_tokens, seed)
447
+ positions, features = (
448
+ sampled_x[..., :position_channels],
449
+ sampled_x[..., position_channels:],
450
+ )
451
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
452
+
453
+ x = self.encoder(x_q, x_kv)
454
+
455
+ x = self.quant(x)
456
+
457
+ return x
458
+
459
+ @apply_forward_hook
460
+ def encode(
461
+ self, x: torch.Tensor, return_dict: bool = True, **kwargs
462
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
463
+ """
464
+ Encode a batch of point features into latents.
465
+ """
466
+ if self.use_slicing and x.shape[0] > 1:
467
+ encoded_slices = [
468
+ self._encode(x_slice, **kwargs)
469
+ for x_slice in x.split(self.slicing_length)
470
+ ]
471
+ h = torch.cat(encoded_slices)
472
+ else:
473
+ h = self._encode(x, **kwargs)
474
+
475
+ posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
476
+
477
+ if not return_dict:
478
+ return (posterior,)
479
+ return AutoencoderKLOutput(latent_dist=posterior)
480
+
481
+ def _decode(
482
+ self,
483
+ z: torch.Tensor,
484
+ sampled_points: torch.Tensor,
485
+ num_chunks: int = 50000,
486
+ to_cpu: bool = False,
487
+ return_dict: bool = True,
488
+ ) -> Union[DecoderOutput, torch.Tensor]:
489
+ xyz_samples = sampled_points
490
+
491
+ z = self.post_quant(z)
492
+
493
+ num_points = xyz_samples.shape[1]
494
+ kv_cache = None
495
+ dec = []
496
+
497
+ for i in range(0, num_points, num_chunks):
498
+ queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
499
+ queries = self.embedder(queries)
500
+
501
+ z_, kv_cache = self.decoder(z, queries, kv_cache)
502
+ dec.append(z_ if not to_cpu else z_.cpu())
503
+
504
+ z = torch.cat(dec, dim=1)
505
+
506
+ if not return_dict:
507
+ return (z,)
508
+
509
+ return DecoderOutput(sample=z)
510
+
511
+ @apply_forward_hook
512
+ def decode(
513
+ self,
514
+ z: torch.Tensor,
515
+ sampled_points: torch.Tensor,
516
+ return_dict: bool = True,
517
+ **kwargs,
518
+ ) -> Union[DecoderOutput, torch.Tensor]:
519
+ if self.use_slicing and z.shape[0] > 1:
520
+ decoded_slices = [
521
+ self._decode(z_slice, p_slice, **kwargs).sample
522
+ for z_slice, p_slice in zip(
523
+ z.split(self.slicing_length),
524
+ sampled_points.split(self.slicing_length),
525
+ )
526
+ ]
527
+ decoded = torch.cat(decoded_slices)
528
+ else:
529
+ decoded = self._decode(z, sampled_points, **kwargs).sample
530
+
531
+ if not return_dict:
532
+ return (decoded,)
533
+ return DecoderOutput(sample=decoded)
534
+
535
+ def forward(self, x: torch.Tensor):
536
+ pass
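A hedged usage sketch for the VAE defined above: encode a surface point cloud into a set of latent tokens, then decode occupancy logits at arbitrary query positions. The shapes, the xyz-plus-normal feature layout, and the random-weight construction are illustrative assumptions; in the released pipeline the weights come from the pretrained checkpoint.

# Illustrative only: random weights, and a 3 + 3 channel layout (xyz + normals)
# so the encoder input width matches in_channels + embedder.out_dim.
import torch

vae = TripoSGVAEModel(in_channels=3, latent_channels=64).eval()

points = torch.randn(1, 20480, 6)                      # [B, N, xyz + normal]

# Encode: FPS-downsample to roughly num_tokens query tokens, cross-attend to all
# points, then split the projection into a diagonal Gaussian posterior.
posterior = vae.encode(points, num_tokens=2048).latent_dist
latents = posterior.sample()                           # [1, ~2048, 64]

# Decode: query the implicit field in chunks; on a dense grid the logits can be
# fed to marching cubes to extract a mesh.
queries = torch.rand(1, 100_000, 3) * 2.0 - 1.0        # query points in [-1, 1]^3
logits = vae.decode(latents, sampled_points=queries, num_chunks=50_000).sample
print(logits.shape)                                    # torch.Size([1, 100000, 1])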
src/models/autoencoders/vae.py ADDED
@@ -0,0 +1,69 @@
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ class DiagonalGaussianDistribution(object):
+     def __init__(
+         self,
+         parameters: torch.Tensor,
+         deterministic: bool = False,
+         feature_dim: int = 1,
+     ):
+         self.parameters = parameters
+         self.feature_dim = feature_dim
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(
+                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+             )
+
+     def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+         # make sure sample is on the same device as the parameters and has same dtype
+         sample = randn_tensor(
+             self.mean.shape,
+             generator=generator,
+             device=self.parameters.device,
+             dtype=self.parameters.dtype,
+         )
+         x = self.mean + self.std * sample
+         return x
+
+     def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         else:
+             if other is None:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                     dim=[1, 2, 3],
+                 )
+             else:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var
+                     - 1.0
+                     - self.logvar
+                     + other.logvar,
+                     dim=[1, 2, 3],
+                 )
+
+     def nll(
+         self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
+     ) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims,
+         )
+
+     def mode(self) -> torch.Tensor:
+         return self.mean
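For intuition, this is the standard reparameterized diagonal Gaussian used by KL-regularized autoencoders; `feature_dim=-1` matches the [B, tokens, 2 * latent_channels] layout produced by the VAE's `quant` projection. A small sketch with made-up shapes:

# Toy check of how the parameter tensor is split and sampled.
import torch

params = torch.randn(2, 2048, 128)        # mean and log-variance stacked on the last axis
dist = DiagonalGaussianDistribution(params, feature_dim=-1)

z = dist.sample()                         # mean + std * eps, differentiable w.r.t. params
z_det = dist.mode()                       # deterministic encoding: just the mean
print(z.shape, z_det.shape)               # torch.Size([2, 2048, 64]) twice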
src/models/briarmbg.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ Source and Copyright Notice:
3
+ This code is from briaai/RMBG-1.4
4
+ Original repository: https://huggingface.co/briaai/RMBG-1.4
5
+ Copyright belongs to briaai
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+ class REBNCONV(nn.Module):
14
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
15
+ super(REBNCONV,self).__init__()
16
+
17
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self,x):
22
+
23
+ hx = x
24
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
25
+
26
+ return xout
27
+
28
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
29
+ def _upsample_like(src,tar):
30
+
31
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
32
+
33
+ return src
34
+
35
+
36
+ ### RSU-7 ###
37
+ class RSU7(nn.Module):
38
+
39
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
40
+ super(RSU7,self).__init__()
41
+
42
+ self.in_ch = in_ch
43
+ self.mid_ch = mid_ch
44
+ self.out_ch = out_ch
45
+
46
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
47
+
48
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
49
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
53
+
54
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
55
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
56
+
57
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
58
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
59
+
60
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
61
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
62
+
63
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
64
+
65
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
66
+
67
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
68
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
69
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
70
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
71
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
72
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
73
+
74
+ def forward(self,x):
75
+ b, c, h, w = x.shape
76
+
77
+ hx = x
78
+ hxin = self.rebnconvin(hx)
79
+
80
+ hx1 = self.rebnconv1(hxin)
81
+ hx = self.pool1(hx1)
82
+
83
+ hx2 = self.rebnconv2(hx)
84
+ hx = self.pool2(hx2)
85
+
86
+ hx3 = self.rebnconv3(hx)
87
+ hx = self.pool3(hx3)
88
+
89
+ hx4 = self.rebnconv4(hx)
90
+ hx = self.pool4(hx4)
91
+
92
+ hx5 = self.rebnconv5(hx)
93
+ hx = self.pool5(hx5)
94
+
95
+ hx6 = self.rebnconv6(hx)
96
+
97
+ hx7 = self.rebnconv7(hx6)
98
+
99
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
100
+ hx6dup = _upsample_like(hx6d,hx5)
101
+
102
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
103
+ hx5dup = _upsample_like(hx5d,hx4)
104
+
105
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
106
+ hx4dup = _upsample_like(hx4d,hx3)
107
+
108
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
109
+ hx3dup = _upsample_like(hx3d,hx2)
110
+
111
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
112
+ hx2dup = _upsample_like(hx2d,hx1)
113
+
114
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
115
+
116
+ return hx1d + hxin
117
+
118
+
119
+ ### RSU-6 ###
120
+ class RSU6(nn.Module):
121
+
122
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
123
+ super(RSU6,self).__init__()
124
+
125
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
126
+
127
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
128
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
129
+
130
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
131
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
132
+
133
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
134
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
135
+
136
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
137
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
138
+
139
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
140
+
141
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
142
+
143
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
144
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
145
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
146
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
147
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
148
+
149
+ def forward(self,x):
150
+
151
+ hx = x
152
+
153
+ hxin = self.rebnconvin(hx)
154
+
155
+ hx1 = self.rebnconv1(hxin)
156
+ hx = self.pool1(hx1)
157
+
158
+ hx2 = self.rebnconv2(hx)
159
+ hx = self.pool2(hx2)
160
+
161
+ hx3 = self.rebnconv3(hx)
162
+ hx = self.pool3(hx3)
163
+
164
+ hx4 = self.rebnconv4(hx)
165
+ hx = self.pool4(hx4)
166
+
167
+ hx5 = self.rebnconv5(hx)
168
+
169
+ hx6 = self.rebnconv6(hx5)
170
+
171
+
172
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
173
+ hx5dup = _upsample_like(hx5d,hx4)
174
+
175
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
176
+ hx4dup = _upsample_like(hx4d,hx3)
177
+
178
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
179
+ hx3dup = _upsample_like(hx3d,hx2)
180
+
181
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
182
+ hx2dup = _upsample_like(hx2d,hx1)
183
+
184
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
185
+
186
+ return hx1d + hxin
187
+
188
+ ### RSU-5 ###
189
+ class RSU5(nn.Module):
190
+
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5,self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
213
+
214
+ def forward(self,x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
234
+ hx4dup = _upsample_like(hx4d,hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
237
+ hx3dup = _upsample_like(hx3d,hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
240
+ hx2dup = _upsample_like(hx2d,hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
243
+
244
+ return hx1d + hxin
245
+
246
+ ### RSU-4 ###
247
+ class RSU4(nn.Module):
248
+
249
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
250
+ super(RSU4,self).__init__()
251
+
252
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
253
+
254
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
255
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
256
+
257
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
258
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
259
+
260
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
261
+
262
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
263
+
264
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
265
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
266
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
267
+
268
+ def forward(self,x):
269
+
270
+ hx = x
271
+
272
+ hxin = self.rebnconvin(hx)
273
+
274
+ hx1 = self.rebnconv1(hxin)
275
+ hx = self.pool1(hx1)
276
+
277
+ hx2 = self.rebnconv2(hx)
278
+ hx = self.pool2(hx2)
279
+
280
+ hx3 = self.rebnconv3(hx)
281
+
282
+ hx4 = self.rebnconv4(hx3)
283
+
284
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
285
+ hx3dup = _upsample_like(hx3d,hx2)
286
+
287
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
288
+ hx2dup = _upsample_like(hx2d,hx1)
289
+
290
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
291
+
292
+ return hx1d + hxin
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F,self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
311
+
312
+ def forward(self,x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(self, in_ch=3,
333
+ out_ch=1,
334
+ kernel_size=3,
335
+ stride=1,
336
+ padding=1,
337
+ dilation=1,
338
+ groups=1):
339
+ super(myrebnconv,self).__init__()
340
+
341
+ self.conv = nn.Conv2d(in_ch,
342
+ out_ch,
343
+ kernel_size=kernel_size,
344
+ stride=stride,
345
+ padding=padding,
346
+ dilation=dilation,
347
+ groups=groups)
348
+ self.bn = nn.BatchNorm2d(out_ch)
349
+ self.rl = nn.ReLU(inplace=True)
350
+
351
+ def forward(self,x):
352
+ return self.rl(self.bn(self.conv(x)))
353
+
354
+
355
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
356
+
357
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
358
+ super(BriaRMBG,self).__init__()
359
+ in_ch=config["in_ch"]
360
+ out_ch=config["out_ch"]
361
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
362
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
363
+
364
+ self.stage1 = RSU7(64,32,64)
365
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
366
+
367
+ self.stage2 = RSU6(64,32,128)
368
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
369
+
370
+ self.stage3 = RSU5(128,64,256)
371
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
372
+
373
+ self.stage4 = RSU4(256,128,512)
374
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
375
+
376
+ self.stage5 = RSU4F(512,256,512)
377
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
378
+
379
+ self.stage6 = RSU4F(512,256,512)
380
+
381
+ # decoder
382
+ self.stage5d = RSU4F(1024,256,512)
383
+ self.stage4d = RSU4(1024,128,256)
384
+ self.stage3d = RSU5(512,64,128)
385
+ self.stage2d = RSU6(256,32,64)
386
+ self.stage1d = RSU7(128,16,64)
387
+
388
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
389
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
390
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
391
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
392
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
393
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
394
+
395
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
396
+
397
+ def forward(self,x):
398
+
399
+ hx = x
400
+
401
+ hxin = self.conv_in(hx)
402
+ #hx = self.pool_in(hxin)
403
+
404
+ #stage 1
405
+ hx1 = self.stage1(hxin)
406
+ hx = self.pool12(hx1)
407
+
408
+ #stage 2
409
+ hx2 = self.stage2(hx)
410
+ hx = self.pool23(hx2)
411
+
412
+ #stage 3
413
+ hx3 = self.stage3(hx)
414
+ hx = self.pool34(hx3)
415
+
416
+ #stage 4
417
+ hx4 = self.stage4(hx)
418
+ hx = self.pool45(hx4)
419
+
420
+ #stage 5
421
+ hx5 = self.stage5(hx)
422
+ hx = self.pool56(hx5)
423
+
424
+ #stage 6
425
+ hx6 = self.stage6(hx)
426
+ hx6up = _upsample_like(hx6,hx5)
427
+
428
+ #-------------------- decoder --------------------
429
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
430
+ hx5dup = _upsample_like(hx5d,hx4)
431
+
432
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
433
+ hx4dup = _upsample_like(hx4d,hx3)
434
+
435
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
436
+ hx3dup = _upsample_like(hx3d,hx2)
437
+
438
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
439
+ hx2dup = _upsample_like(hx2d,hx1)
440
+
441
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
442
+
443
+
444
+ #side output
445
+ d1 = self.side1(hx1d)
446
+ d1 = _upsample_like(d1,x)
447
+
448
+ d2 = self.side2(hx2d)
449
+ d2 = _upsample_like(d2,x)
450
+
451
+ d3 = self.side3(hx3d)
452
+ d3 = _upsample_like(d3,x)
453
+
454
+ d4 = self.side4(hx4d)
455
+ d4 = _upsample_like(d4,x)
456
+
457
+ d5 = self.side5(hx5d)
458
+ d5 = _upsample_like(d5,x)
459
+
460
+ d6 = self.side6(hx6)
461
+ d6 = _upsample_like(d6,x)
462
+
463
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
464
+
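BriaRMBG is vendored here for foreground matting of the input image. A hedged sketch of running it with the published RMBG-1.4 weights follows; the resize-to-1024 and mean-0.5 normalization are taken from the upstream model card rather than from anything in this file:

# Illustrative use of the matting network defined above (assumptions noted).
import torch
import torch.nn.functional as F

net = BriaRMBG.from_pretrained("briaai/RMBG-1.4").eval()   # via PyTorchModelHubMixin

image = torch.rand(1, 3, 768, 512)                          # RGB in [0, 1]
x = F.interpolate(image, size=(1024, 1024), mode="bilinear")
x = x - 0.5                                                 # normalize(mean=0.5, std=1.0)

with torch.no_grad():
    side_maps, _ = net(x)
alpha = side_maps[0]                                        # finest sigmoid map, [1, 1, 1024, 1024]
alpha = F.interpolate(alpha, size=image.shape[2:], mode="bilinear")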
src/models/embeddings.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn as nn
+
+
+ class FrequencyPositionalEmbedding(nn.Module):
+     """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+     each feature dimension of `x[..., i]` into:
+         [
+             sin(x[..., i]),
+             sin(f_1*x[..., i]),
+             sin(f_2*x[..., i]),
+             ...
+             sin(f_N * x[..., i]),
+             cos(x[..., i]),
+             cos(f_1*x[..., i]),
+             cos(f_2*x[..., i]),
+             ...
+             cos(f_N * x[..., i]),
+             x[..., i]  # only present if include_input is True.
+         ], here f_i is the frequency.
+
+     If logspace is True, the frequencies are the powers of two [2^0, 2^1, ..., 2^(num_freqs - 1)];
+     otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
+
+     Args:
+         num_freqs (int): the number of frequencies, default is 6;
+         logspace (bool): if True, the frequencies are the powers of two [2^0, ..., 2^(num_freqs - 1)],
+             otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
+         input_dim (int): the input dimension, default is 3;
+         include_input (bool): include the input tensor or not, default is True;
+         include_pi (bool): multiply the frequencies by pi, default is True.
+
+     Attributes:
+         frequencies (torch.Tensor): the frequency buffer described above (scaled by pi if include_pi is True);
+
+         out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
+             otherwise, it is input_dim * num_freqs * 2.
+
+     """
+
+     def __init__(
+         self,
+         num_freqs: int = 6,
+         logspace: bool = True,
+         input_dim: int = 3,
+         include_input: bool = True,
+         include_pi: bool = True,
+     ) -> None:
+         """The initialization"""
+
+         super().__init__()
+
+         if logspace:
+             frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
+         else:
+             frequencies = torch.linspace(
+                 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
+             )
+
+         if include_pi:
+             frequencies *= torch.pi
+
+         self.register_buffer("frequencies", frequencies, persistent=False)
+         self.include_input = include_input
+         self.num_freqs = num_freqs
+
+         self.out_dim = self.get_dims(input_dim)
+
+     def get_dims(self, input_dim):
+         temp = 1 if self.include_input or self.num_freqs == 0 else 0
+         out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+         return out_dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward process.
+
+         Args:
+             x: tensor of shape [..., dim]
+
+         Returns:
+             embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+                 where temp is 1 if include_input is True and 0 otherwise.
+         """
+
+         if self.num_freqs > 0:
+             embed = (x[..., None].contiguous() * self.frequencies).view(
+                 *x.shape[:-1], -1
+             )
+             if self.include_input:
+                 return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+             else:
+                 return torch.cat((embed.sin(), embed.cos()), dim=-1)
+         else:
+             return x
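A quick check of the frequency embedding, using the same settings the VAE above passes in (num_freqs=8, include_pi=False): each 3-D point expands to 3 * (2 * 8 + 1) = 51 features, which is where the encoder's `in_channels + embedder.out_dim` input width comes from.

# Toy check of the embedding size and output shape.
import torch

embedder = FrequencyPositionalEmbedding(
    num_freqs=8, logspace=True, input_dim=3, include_input=True, include_pi=False
)
points = torch.rand(4, 2048, 3) * 2.0 - 1.0     # xyz in [-1, 1]
feats = embedder(points)
print(embedder.out_dim, feats.shape)            # 51 torch.Size([4, 2048, 51])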
src/models/transformers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from typing import Callable, Optional
+
+ from .partcrafter_transformer import PartCrafterDiTModel
src/models/transformers/modeling_outputs.py ADDED
@@ -0,0 +1,8 @@
+ from dataclasses import dataclass
+
+ import torch
+
+
+ @dataclass
+ class Transformer1DModelOutput:
+     sample: torch.FloatTensor
src/models/transformers/partcrafter_transformer.py ADDED
@@ -0,0 +1,813 @@
1
+ # Copyright (c) 2025 Yuchen Lin
2
+
3
+ # This code is based on TripoSG (https://github.com/VAST-AI-Research/TripoSG). Below is the statement from the original repository:
4
+
5
+ # This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
6
+ # which is licensed under the Tencent Hunyuan Community License Agreement.
7
+ # Portions of this code are copied or adapted from HunyuanDiT.
8
+ # See the original license below:
9
+
10
+ # ---- Start of Tencent Hunyuan Community License Agreement ----
11
+
12
+ # TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
13
+ # Tencent Hunyuan DiT Release Date: 14 May 2024
14
+ # THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
15
+ # By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
16
+ # 1. DEFINITIONS.
17
+ # a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
18
+ # b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
19
+ # c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
20
+ # d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
21
+ # e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
22
+ # f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
23
+ # g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
24
+ # h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
25
+ # i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
26
+ # j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
27
+ # k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
28
+ # l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
29
+ # m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
30
+ # n. “including” shall mean including but not limited to.
31
+ # 2. GRANT OF RIGHTS.
32
+ # We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
33
+ # 3. DISTRIBUTION.
34
+ # You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
35
+ # a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
36
+ # b. You must cause any modified files to carry prominent notices stating that You changed the files;
37
+ # c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
38
+ # d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
39
+ # You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
40
+ # 4. ADDITIONAL COMMERCIAL TERMS.
41
+ # If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
42
+ # 5. RULES OF USE.
43
+ # a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
44
+ # b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
45
+ # c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
46
+ # 6. INTELLECTUAL PROPERTY.
47
+ # a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
48
+ # b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
49
+ # c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
50
+ # d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
51
+ # 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
52
+ # a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
53
+ # b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
54
+ # c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
55
+ # 8. SURVIVAL AND TERMINATION.
56
+ # a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
57
+ # b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
58
+ # 9. GOVERNING LAW AND JURISDICTION.
59
+ # a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
60
+ # b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
61
+ #
62
+ # EXHIBIT A
63
+ # ACCEPTABLE USE POLICY
64
+
65
+ # Tencent reserves the right to update this Acceptable Use Policy from time to time.
66
+ # Last modified: [insert date]
67
+
68
+ # Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
69
+ # 1. Outside the Territory;
70
+ # 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
71
+ # 3. To harm Yourself or others;
72
+ # 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
73
+ # 5. To override or circumvent the safety guardrails and safeguards We have put in place;
74
+ # 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
75
+ # 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
76
+ # 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
77
+ # 9. To intentionally defame, disparage or otherwise harass others;
78
+ # 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
79
+ # 11. To generate or disseminate personal identifiable information with the purpose of harming others;
80
+ # 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
81
+ # 13. To impersonate another individual without consent, authorization, or legal right;
82
+ # 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
83
+ # 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
84
+ # 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
85
+ # 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
86
+ # 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
87
+ # 19. For military purposes;
88
+ # 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
89
+
90
+ # ---- End of Tencent Hunyuan Community License Agreement ----
91
+
92
+ # Please note that the use of this code is subject to the terms and conditions
93
+ # of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
94
+
95
+ from typing import *
96
+
97
+ import torch
98
+ import torch.utils.checkpoint
99
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
100
+ from diffusers.loaders import PeftAdapterMixin
101
+ from diffusers.models.attention import FeedForward
102
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
103
+ from diffusers.models.embeddings import (
104
+ GaussianFourierProjection,
105
+ TimestepEmbedding,
106
+ Timesteps,
107
+ )
108
+ from diffusers.models.modeling_utils import ModelMixin
109
+ from diffusers.models.normalization import (
110
+ AdaLayerNormContinuous,
111
+ FP32LayerNorm,
112
+ LayerNorm,
113
+ )
114
+ from diffusers.utils import (
115
+ USE_PEFT_BACKEND,
116
+ is_torch_version,
117
+ logging,
118
+ scale_lora_layers,
119
+ unscale_lora_layers,
120
+ )
121
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
122
+ from torch import nn
123
+
124
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, PartCrafterAttnProcessor
125
+ from .modeling_outputs import Transformer1DModelOutput
126
+
127
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
128
+
129
+
130
+ @maybe_allow_in_graph
131
+ class DiTBlock(nn.Module):
132
+ r"""
133
+ Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connections and
134
+ QK normalization.
135
+
136
+ Parameters:
137
+ dim (`int`):
138
+ The number of channels in the input and output.
139
+ num_attention_heads (`int`):
140
+ The number of heads to use for multi-head attention.
141
+ cross_attention_dim (`int`, *optional*):
142
+ The size of the encoder_hidden_states vector for cross attention.
143
+ dropout(`float`, *optional*, defaults to 0.0):
144
+ The dropout probability to use.
145
+ activation_fn (`str`, *optional*, defaults to `"gelu"`):
146
+ Activation function to be used in the feed-forward block.
147
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
148
+ Whether to use learnable elementwise affine parameters for normalization.
149
+ norm_eps (`float`, *optional*, defaults to 1e-6):
150
+ A small constant added to the denominator in normalization layers to prevent division by zero.
151
+ final_dropout (`bool` *optional*, defaults to False):
152
+ Whether to apply a final dropout after the last feed-forward layer.
153
+ ff_inner_dim (`int`, *optional*):
154
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
155
+ ff_bias (`bool`, *optional*, defaults to `True`):
156
+ Whether to use bias in the feed-forward block.
157
+ skip (`bool`, *optional*, defaults to `False`):
158
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
159
+ qk_norm (`bool`, *optional*, defaults to `True`):
160
+ Whether to use normalization in QK calculation. Defaults to `True`.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ dim: int,
166
+ num_attention_heads: int,
167
+ use_self_attention: bool = True,
168
+ self_attention_norm_type: Optional[str] = None,
169
+ use_cross_attention: bool = True, # ada layer norm
170
+ cross_attention_dim: Optional[int] = None,
171
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
172
+ dropout=0.0,
173
+ activation_fn: str = "gelu",
174
+ norm_type: str = "fp32_layer_norm", # TODO
175
+ norm_elementwise_affine: bool = True,
176
+ norm_eps: float = 1e-5,
177
+ final_dropout: bool = False,
178
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
179
+ ff_bias: bool = True,
180
+ skip: bool = False,
181
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
182
+ skip_norm_last: bool = False, # this is an error
183
+ qk_norm: bool = True,
184
+ qkv_bias: bool = True,
185
+ ):
186
+ super().__init__()
187
+
188
+ self.use_self_attention = use_self_attention
189
+ self.use_cross_attention = use_cross_attention
190
+ self.skip_concat_front = skip_concat_front
191
+ self.skip_norm_last = skip_norm_last
192
+ # Define 3 blocks. Each block has its own normalization layer.
193
+ # NOTE: when new version comes, check norm2 and norm 3
194
+ # 1. Self-Attn
195
+ if use_self_attention:
196
+ if (
197
+ self_attention_norm_type == "fp32_layer_norm"
198
+ or self_attention_norm_type is None
199
+ ):
200
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
201
+ else:
202
+ raise NotImplementedError
203
+
204
+ self.attn1 = Attention(
205
+ query_dim=dim,
206
+ cross_attention_dim=None,
207
+ dim_head=dim // num_attention_heads,
208
+ heads=num_attention_heads,
209
+ qk_norm="rms_norm" if qk_norm else None,
210
+ eps=1e-6,
211
+ bias=qkv_bias,
212
+ processor=TripoSGAttnProcessor2_0(),
213
+ )
214
+
215
+ # 2. Cross-Attn
216
+ if use_cross_attention:
217
+ assert cross_attention_dim is not None
218
+
219
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
220
+
221
+ self.attn2 = Attention(
222
+ query_dim=dim,
223
+ cross_attention_dim=cross_attention_dim,
224
+ dim_head=dim // num_attention_heads,
225
+ heads=num_attention_heads,
226
+ qk_norm="rms_norm" if qk_norm else None,
227
+ cross_attention_norm=cross_attention_norm_type,
228
+ eps=1e-6,
229
+ bias=qkv_bias,
230
+ processor=TripoSGAttnProcessor2_0(),
231
+ )
232
+
233
+ # 3. Feed-forward
234
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
235
+
236
+ self.ff = FeedForward(
237
+ dim,
238
+ dropout=dropout, ### 0.0
239
+ activation_fn=activation_fn, ### approx GeLU
240
+ final_dropout=final_dropout, ### 0.0
241
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
242
+ bias=ff_bias,
243
+ )
244
+
245
+ # 4. Skip Connection
246
+ if skip:
247
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
248
+ self.skip_linear = nn.Linear(2 * dim, dim)
249
+ else:
250
+ self.skip_linear = None
251
+
252
+ # let chunk size default to None
253
+ self._chunk_size = None
254
+ self._chunk_dim = 0
255
+
256
+ def set_topk(self, topk):
257
+ self.flash_processor.topk = topk
258
+
259
+ def set_flash_processor(self, flash_processor):
260
+ self.flash_processor = flash_processor
261
+ self.attn2.processor = self.flash_processor
262
+
263
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
264
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
265
+ # Sets chunk feed-forward
266
+ self._chunk_size = chunk_size
267
+ self._chunk_dim = dim
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states: torch.Tensor,
272
+ encoder_hidden_states: Optional[torch.Tensor] = None,
273
+ temb: Optional[torch.Tensor] = None,
274
+ image_rotary_emb: Optional[torch.Tensor] = None,
275
+ skip: Optional[torch.Tensor] = None,
276
+ attention_kwargs: Optional[Dict[str, Any]] = None,
277
+ ) -> torch.Tensor:
278
+ # Prepare attention kwargs
279
+ attention_kwargs = attention_kwargs or {}
280
+
281
+ # Notice that normalization is always applied before the real computation in the following blocks.
282
+ # 0. Long Skip Connection
283
+ if self.skip_linear is not None:
284
+ cat = torch.cat(
285
+ (
286
+ [skip, hidden_states]
287
+ if self.skip_concat_front
288
+ else [hidden_states, skip]
289
+ ),
290
+ dim=-1,
291
+ )
292
+ if self.skip_norm_last:
293
+ # don't do this
294
+ hidden_states = self.skip_linear(cat)
295
+ hidden_states = self.skip_norm(hidden_states)
296
+ else:
297
+ cat = self.skip_norm(cat)
298
+ hidden_states = self.skip_linear(cat)
299
+
300
+ # 1. Self-Attention
301
+ if self.use_self_attention:
302
+ norm_hidden_states = self.norm1(hidden_states)
303
+ attn_output = self.attn1(
304
+ norm_hidden_states,
305
+ image_rotary_emb=image_rotary_emb,
306
+ **attention_kwargs,
307
+ )
308
+ hidden_states = hidden_states + attn_output
309
+
310
+ # 2. Cross-Attention
311
+ if self.use_cross_attention:
312
+ hidden_states = hidden_states + self.attn2(
313
+ self.norm2(hidden_states),
314
+ encoder_hidden_states=encoder_hidden_states,
315
+ image_rotary_emb=image_rotary_emb,
316
+ **attention_kwargs,
317
+ )
318
+
319
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
320
+ mlp_inputs = self.norm3(hidden_states)
321
+ hidden_states = hidden_states + self.ff(mlp_inputs)
322
+
323
+ return hidden_states
324
+
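For readers skimming the block above: the long skip connection concatenates the saved activation with the current hidden states, normalizes the concatenation, and projects it back to `dim`. The following is a minimal, self-contained sketch of that wiring in plain PyTorch with hypothetical sizes, not the repository's own module.

import torch
from torch import nn

batch, tokens, dim = 2, 9, 64                 # hypothetical sizes
skip_norm = nn.LayerNorm(2 * dim)             # stands in for FP32LayerNorm
skip_linear = nn.Linear(2 * dim, dim)

hidden_states = torch.randn(batch, tokens, dim)
skip = torch.randn(batch, tokens, dim)        # activation saved by an earlier block

# skip_concat_front=True puts the skip tensor first, i.e. [skip, x]
cat = torch.cat([skip, hidden_states], dim=-1)    # (batch, tokens, 2*dim)
cat = skip_norm(cat)                              # normalize first (skip_norm_last=False path)
hidden_states = skip_linear(cat)                  # project back to (batch, tokens, dim)
print(hidden_states.shape)                        # torch.Size([2, 9, 64])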
325
+ # Modified from https://github.com/VAST-AI-Research/TripoSG/blob/main/triposg/models/transformers/triposg_transformer.py#L365
326
+ class PartCrafterDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
327
+ """
328
+ PartCrafter: Diffusion model with a Transformer backbone.
329
+
330
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
331
+
332
+ Parameters:
333
+ num_attention_heads (`int`, *optional*, defaults to 16):
334
+ The number of heads to use for multi-head attention.
335
+ attention_head_dim (`int`, *optional*, defaults to 88):
336
+ The number of channels in each head.
337
+ in_channels (`int`, *optional*):
338
+ The number of channels in the input and output (specify if the input is **continuous**).
339
+ patch_size (`int`, *optional*):
340
+ The size of the patch to use for the input.
341
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
342
+ Activation function to use in feed-forward.
343
+ sample_size (`int`, *optional*):
344
+ The width of the latent images. This is fixed during training since it is used to learn a number of
345
+ position embeddings.
346
+ dropout (`float`, *optional*, defaults to 0.0):
347
+ The dropout probability to use.
348
+ cross_attention_dim (`int`, *optional*):
349
+ The number of dimensions in the CLIP text embedding.
350
+ hidden_size (`int`, *optional*):
351
+ The size of hidden layer in the conditioning embedding layers.
352
+ num_layers (`int`, *optional*, defaults to 1):
353
+ The number of layers of Transformer blocks to use.
354
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
355
+ The ratio of the hidden layer size to the input size.
356
+ learn_sigma (`bool`, *optional*, defaults to `True`):
357
+ Whether to predict variance.
358
+ cross_attention_dim_t5 (`int`, *optional*):
359
+ The number of dimensions in the T5 text embedding.
360
+ pooled_projection_dim (`int`, *optional*):
361
+ The size of the pooled projection.
362
+ text_len (`int`, *optional*):
363
+ The length of the clip text embedding.
364
+ text_len_t5 (`int`, *optional*):
365
+ The length of the T5 text embedding.
366
+ use_style_cond_and_image_meta_size (`bool`, *optional*):
367
+ Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
368
+ """
369
+
370
+ _supports_gradient_checkpointing = True
371
+
372
+ @register_to_config
373
+ def __init__(
374
+ self,
375
+ num_attention_heads: int = 16,
376
+ width: int = 2048,
377
+ in_channels: int = 64,
378
+ num_layers: int = 21,
379
+ cross_attention_dim: int = 1024,
380
+ max_num_parts: int = 32,
381
+ enable_part_embedding=True,
382
+ enable_local_cross_attn: bool = True,
383
+ enable_global_cross_attn: bool = True,
384
+ global_attn_block_ids: Optional[List[int]] = None,
385
+ global_attn_block_id_range: Optional[List[int]] = None,
386
+ ):
387
+ super().__init__()
388
+ self.out_channels = in_channels
389
+ self.num_heads = num_attention_heads
390
+ self.inner_dim = width
391
+ self.mlp_ratio = 4.0
392
+
393
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
394
+ "positional",
395
+ inner_dim=self.inner_dim,
396
+ flip_sin_to_cos=False,
397
+ freq_shift=0,
398
+ time_embedding_dim=None,
399
+ )
400
+ self.time_proj = TimestepEmbedding(
401
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
402
+ )
403
+
404
+ if enable_part_embedding:
405
+ self.part_embedding = nn.Embedding(max_num_parts, self.inner_dim)
406
+ self.part_embedding.weight.data.normal_(mean=0.0, std=0.02)
407
+ self.enable_part_embedding = enable_part_embedding
408
+
409
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
410
+
411
+ self.blocks = nn.ModuleList(
412
+ [
413
+ DiTBlock(
414
+ dim=self.inner_dim,
415
+ num_attention_heads=self.config.num_attention_heads,
416
+ use_self_attention=True,
417
+ self_attention_norm_type="fp32_layer_norm",
418
+ use_cross_attention=True,
419
+ cross_attention_dim=cross_attention_dim,
420
+ cross_attention_norm_type=None,
421
+ activation_fn="gelu",
422
+ norm_type="fp32_layer_norm", # TODO
423
+ norm_eps=1e-5,
424
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
425
+ skip=layer > num_layers // 2,
426
+ skip_concat_front=True,
427
+ skip_norm_last=True, # this is an error
428
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
429
+ qkv_bias=False,
430
+ )
431
+ for layer in range(num_layers)
432
+ ]
433
+ )
434
+
435
+ self.norm_out = LayerNorm(self.inner_dim)
436
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
437
+
438
+ self.gradient_checkpointing = False
439
+
440
+ self.enable_local_cross_attn = enable_local_cross_attn
441
+ self.enable_global_cross_attn = enable_global_cross_attn
442
+
443
+ if global_attn_block_ids is None:
444
+ global_attn_block_ids = []
445
+ if global_attn_block_id_range is not None:
446
+ global_attn_block_ids = list(range(global_attn_block_id_range[0], global_attn_block_id_range[1] + 1))
447
+ self.global_attn_block_ids = global_attn_block_ids
448
+
449
+ if len(global_attn_block_ids) > 0:
450
+ # Override self-attention processors for global attention blocks
451
+ attn_processor_dict = {}
452
+ modified_attn_processor = []
453
+ for layer_id in range(num_layers):
454
+ for attn_id in [1, 2]:
455
+ if layer_id in global_attn_block_ids:
456
+ # apply to both self-attention and cross-attention
457
+ attn_processor_dict[f'blocks.{layer_id}.attn{attn_id}.processor'] = PartCrafterAttnProcessor()
458
+ modified_attn_processor.append(f'blocks.{layer_id}.attn{attn_id}.processor')
459
+ else:
460
+ attn_processor_dict[f'blocks.{layer_id}.attn{attn_id}.processor'] = TripoSGAttnProcessor2_0()
461
+ self.set_attn_processor(attn_processor_dict)
462
+ # logger.info(f"Modified {modified_attn_processor} to PartCrafterAttnProcessor")
463
+
464
+ def _set_gradient_checkpointing(
465
+ self,
466
+ enable: bool = False,
467
+ gradient_checkpointing_func: Optional[Callable] = None,
468
+ ):
469
+ # TODO: implement gradient checkpointing
470
+ self.gradient_checkpointing = enable
471
+
472
+ def _set_time_proj(
473
+ self,
474
+ time_embedding_type: str,
475
+ inner_dim: int,
476
+ flip_sin_to_cos: bool,
477
+ freq_shift: float,
478
+ time_embedding_dim: int,
479
+ ) -> Tuple[int, int]:
480
+ if time_embedding_type == "fourier":
481
+ time_embed_dim = time_embedding_dim or inner_dim * 2
482
+ if time_embed_dim % 2 != 0:
483
+ raise ValueError(
484
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
485
+ )
486
+ self.time_embed = GaussianFourierProjection(
487
+ time_embed_dim // 2,
488
+ set_W_to_weight=False,
489
+ log=False,
490
+ flip_sin_to_cos=flip_sin_to_cos,
491
+ )
492
+ timestep_input_dim = time_embed_dim
493
+ elif time_embedding_type == "positional":
494
+ time_embed_dim = time_embedding_dim or inner_dim * 4
495
+
496
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
497
+ timestep_input_dim = inner_dim
498
+ else:
499
+ raise ValueError(
500
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
501
+ )
502
+
503
+ return time_embed_dim, timestep_input_dim
504
+
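The `positional` branch above builds sinusoidal features of width `inner_dim` and maps them through the `TimestepEmbedding` MLP registered in `__init__`. Below is a minimal sketch with the same diffusers helpers and a hypothetical width, not the checkpointed configuration.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

inner_dim = 32                                            # hypothetical width
time_embed = Timesteps(inner_dim, flip_sin_to_cos=False, downscale_freq_shift=0)
time_proj = TimestepEmbedding(inner_dim, inner_dim * 4, act_fn="gelu", out_dim=inner_dim)

timestep = torch.tensor([999.0, 500.0])                   # one timestep per sample
temb = time_proj(time_embed(timestep))                    # (2, inner_dim)
print(temb.shape)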
505
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
506
+ def fuse_qkv_projections(self):
507
+ """
508
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
509
+ are fused. For cross-attention modules, key and value projection matrices are fused.
510
+
511
+ <Tip warning={true}>
512
+
513
+ This API is 🧪 experimental.
514
+
515
+ </Tip>
516
+ """
517
+ self.original_attn_processors = None
518
+
519
+ for _, attn_processor in self.attn_processors.items():
520
+ if "Added" in str(attn_processor.__class__.__name__):
521
+ raise ValueError(
522
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
523
+ )
524
+
525
+ self.original_attn_processors = self.attn_processors
526
+
527
+ for module in self.modules():
528
+ if isinstance(module, Attention):
529
+ module.fuse_projections(fuse=True)
530
+
531
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
532
+
533
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
534
+ def unfuse_qkv_projections(self):
535
+ """Disables the fused QKV projection if enabled.
536
+
537
+ <Tip warning={true}>
538
+
539
+ This API is 🧪 experimental.
540
+
541
+ </Tip>
542
+
543
+ """
544
+ if self.original_attn_processors is not None:
545
+ self.set_attn_processor(self.original_attn_processors)
546
+
547
+ @property
548
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
549
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
550
+ r"""
551
+ Returns:
552
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
553
+ indexed by its weight name.
554
+ """
555
+ # set recursively
556
+ processors = {}
557
+
558
+ def fn_recursive_add_processors(
559
+ name: str,
560
+ module: torch.nn.Module,
561
+ processors: Dict[str, AttentionProcessor],
562
+ ):
563
+ if hasattr(module, "get_processor"):
564
+ processors[f"{name}.processor"] = module.get_processor()
565
+
566
+ for sub_name, child in module.named_children():
567
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
568
+
569
+ return processors
570
+
571
+ for name, module in self.named_children():
572
+ fn_recursive_add_processors(name, module, processors)
573
+
574
+ return processors
575
+
576
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
577
+ def set_attn_processor(
578
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
579
+ ):
580
+ r"""
581
+ Sets the attention processor to use to compute attention.
582
+
583
+ Parameters:
584
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
585
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
586
+ for **all** `Attention` layers.
587
+
588
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
589
+ processor. This is strongly recommended when setting trainable attention processors.
590
+
591
+ """
592
+ count = len(self.attn_processors.keys())
593
+
594
+ if isinstance(processor, dict) and len(processor) != count:
595
+ raise ValueError(
596
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
597
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
598
+ )
599
+
600
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
601
+ if hasattr(module, "set_processor"):
602
+ if not isinstance(processor, dict):
603
+ module.set_processor(processor)
604
+ else:
605
+ module.set_processor(processor.pop(f"{name}.processor"))
606
+
607
+ for sub_name, child in module.named_children():
608
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
609
+
610
+ for name, module in self.named_children():
611
+ fn_recursive_attn_processor(name, module, processor)
612
+
613
+ def set_default_attn_processor(self):
614
+ """
615
+ Disables custom attention processors and sets the default attention implementation.
616
+ """
617
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
618
+
619
+ def forward(
620
+ self,
621
+ hidden_states: Optional[torch.Tensor],
622
+ timestep: Union[int, float, torch.LongTensor],
623
+ encoder_hidden_states: Optional[torch.Tensor] = None,
624
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
625
+ attention_kwargs: Optional[Dict[str, Any]] = None,
626
+ return_dict: bool = True,
627
+ ):
628
+ """
629
+ The [`PartCrafterDiTModel`] forward method.
630
+
631
+ Args:
632
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, in_channels)`):
633
+ The input tensor.
634
+ timestep ( `torch.LongTensor`, *optional*):
635
+ Used to indicate denoising step.
636
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
637
+ Conditional embeddings for cross attention layer.
638
+ return_dict: bool
639
+ Whether to return a dictionary.
640
+ """
641
+
642
+ if attention_kwargs is not None:
643
+ attention_kwargs = attention_kwargs.copy()
644
+ lora_scale = attention_kwargs.pop("scale", 1.0)
645
+ else:
646
+ lora_scale = 1.0
647
+
648
+ if USE_PEFT_BACKEND:
649
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
650
+ scale_lora_layers(self, lora_scale)
651
+ else:
652
+ if (
653
+ attention_kwargs is not None
654
+ and attention_kwargs.get("scale", None) is not None
655
+ ):
656
+ logger.warning(
657
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
658
+ )
659
+
660
+ _, T, _ = hidden_states.shape
661
+
662
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
663
+ temb = self.time_proj(temb)
664
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
665
+
666
+ hidden_states = self.proj_in(hidden_states)
667
+
668
+ # T + 1 token
669
+ hidden_states = torch.cat([temb, hidden_states], dim=1) # (N, T+1, D)
670
+
671
+ if self.enable_part_embedding:
672
+ # Add part embedding
673
+ num_parts = attention_kwargs["num_parts"]
674
+ if isinstance(num_parts, torch.Tensor):
675
+ part_embeddings = []
676
+ for num_part in num_parts:
677
+ part_embedding = self.part_embedding(torch.arange(num_part, device=hidden_states.device)) # (n, D)
678
+ part_embeddings.append(part_embedding)
679
+ part_embedding = torch.cat(part_embeddings, dim=0) # (N, D)
680
+ elif isinstance(num_parts, int):
681
+ part_embedding = self.part_embedding(torch.arange(hidden_states.shape[0], device=hidden_states.device)) # (N, D)
682
+ else:
683
+ raise ValueError(
684
+ "num_parts must be a torch.Tensor or int, but got {}".format(type(num_parts))
685
+ )
686
+ hidden_states = hidden_states + part_embedding.unsqueeze(dim=1) # (N, T+1, D)
687
+
688
+ # prepare negative encoder_hidden_states
689
+ negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) if encoder_hidden_states is not None else None
690
+
691
+ skips = []
692
+ for layer, block in enumerate(self.blocks):
693
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
694
+ if (
695
+ (not self.enable_local_cross_attn)
696
+ and len(self.global_attn_block_ids) > 0
697
+ and (layer not in self.global_attn_block_ids)
698
+ ):
699
+ # If in non-global attention block and disable local cross attention, use negative encoder_hidden_states
700
+ # Do not inject control signal into non-global attention block
701
+ input_encoder_hidden_states = negative_encoder_hidden_states
702
+ elif (
703
+ (not self.enable_global_cross_attn)
704
+ and len(self.global_attn_block_ids) > 0
705
+ and (layer in self.global_attn_block_ids)
706
+ ):
707
+ # If in global attention block and disable global cross attention, use negative encoder_hidden_states
708
+ # Do not inject control signal into global attention block
709
+ input_encoder_hidden_states = negative_encoder_hidden_states
710
+ else:
711
+ input_encoder_hidden_states = encoder_hidden_states
712
+
713
+ if len(self.global_attn_block_ids) > 0 and (layer in self.global_attn_block_ids):
714
+ # Inject control signal into global attention block
715
+ input_attention_kwargs = attention_kwargs
716
+ else:
717
+ input_attention_kwargs = None
718
+
719
+ if self.training and self.gradient_checkpointing:
720
+
721
+ def create_custom_forward(module):
722
+ def custom_forward(*inputs):
723
+ return module(*inputs)
724
+
725
+ return custom_forward
726
+
727
+ ckpt_kwargs: Dict[str, Any] = (
728
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
729
+ )
730
+ hidden_states = torch.utils.checkpoint.checkpoint(
731
+ create_custom_forward(block),
732
+ hidden_states,
733
+ input_encoder_hidden_states,
734
+ temb,
735
+ image_rotary_emb,
736
+ skip,
737
+ input_attention_kwargs,
738
+ **ckpt_kwargs,
739
+ )
740
+ else:
741
+ hidden_states = block(
742
+ hidden_states,
743
+ encoder_hidden_states=input_encoder_hidden_states,
744
+ temb=temb,
745
+ image_rotary_emb=image_rotary_emb,
746
+ skip=skip,
747
+ attention_kwargs=input_attention_kwargs,
748
+ ) # (N, T+1, D)
749
+
750
+ if layer < self.config.num_layers // 2:
751
+ skips.append(hidden_states)
752
+
753
+ # final layer
754
+ hidden_states = self.norm_out(hidden_states)
755
+ hidden_states = hidden_states[:, -T:] # (N, T, D)
756
+ hidden_states = self.proj_out(hidden_states)
757
+
758
+ if USE_PEFT_BACKEND:
759
+ # remove `lora_scale` from each PEFT layer
760
+ unscale_lora_layers(self, lora_scale)
761
+
762
+ if not return_dict:
763
+ return (hidden_states,)
764
+
765
+ return Transformer1DModelOutput(sample=hidden_states)
766
+
767
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
768
+ def enable_forward_chunking(
769
+ self, chunk_size: Optional[int] = None, dim: int = 0
770
+ ) -> None:
771
+ """
772
+ Sets the attention processor to use [feed forward
773
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
774
+
775
+ Parameters:
776
+ chunk_size (`int`, *optional*):
777
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
778
+ over each tensor of dim=`dim`.
779
+ dim (`int`, *optional*, defaults to `0`):
780
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
781
+ or dim=1 (sequence length).
782
+ """
783
+ if dim not in [0, 1]:
784
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
785
+
786
+ # By default chunk size is 1
787
+ chunk_size = chunk_size or 1
788
+
789
+ def fn_recursive_feed_forward(
790
+ module: torch.nn.Module, chunk_size: int, dim: int
791
+ ):
792
+ if hasattr(module, "set_chunk_feed_forward"):
793
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
794
+
795
+ for child in module.children():
796
+ fn_recursive_feed_forward(child, chunk_size, dim)
797
+
798
+ for module in self.children():
799
+ fn_recursive_feed_forward(module, chunk_size, dim)
800
+
801
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
802
+ def disable_forward_chunking(self):
803
+ def fn_recursive_feed_forward(
804
+ module: torch.nn.Module, chunk_size: int, dim: int
805
+ ):
806
+ if hasattr(module, "set_chunk_feed_forward"):
807
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
808
+
809
+ for child in module.children():
810
+ fn_recursive_feed_forward(child, chunk_size, dim)
811
+
812
+ for module in self.children():
813
+ fn_recursive_feed_forward(module, None, 0)
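The PartCrafter-specific step in `forward` is the part embedding: every part-latent in the batch receives a learned offset indexed by its position within its object, broadcast over all tokens. A self-contained sketch with toy sizes (two objects with 2 and 3 parts; not the model's real dimensions):

import torch
from torch import nn

max_num_parts, inner_dim, T = 8, 16, 4                    # toy sizes
part_embedding = nn.Embedding(max_num_parts, inner_dim)

num_parts = torch.tensor([2, 3])                          # parts per object in the batch
ids = torch.cat([torch.arange(int(n)) for n in num_parts])  # [0, 1, 0, 1, 2]
per_part = part_embedding(ids)                            # (5, inner_dim)

hidden_states = torch.randn(int(num_parts.sum()), T + 1, inner_dim)  # T latents + 1 time token each
hidden_states = hidden_states + per_part.unsqueeze(1)     # broadcast over the token axis
print(hidden_states.shape)                                # torch.Size([5, 5, 16])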
src/pipelines/pipeline_partcrafter.py ADDED
@@ -0,0 +1,355 @@
1
+ import inspect
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import PIL.Image
8
+ import torch
9
+ import trimesh
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers import (
16
+ BitImageProcessor,
17
+ Dinov2Model,
18
+ )
19
+ from ..utils.inference_utils import hierarchical_extract_geometry
20
+
21
+ from ..models.autoencoders import TripoSGVAEModel
22
+ from ..models.transformers import PartCrafterDiTModel
23
+ from .pipeline_partcrafter_output import PartCrafterPipelineOutput
24
+ from .pipeline_utils import TransformerDiffusionMixin
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
30
+ def retrieve_timesteps(
31
+ scheduler,
32
+ num_inference_steps: Optional[int] = None,
33
+ device: Optional[Union[str, torch.device]] = None,
34
+ timesteps: Optional[List[int]] = None,
35
+ sigmas: Optional[List[float]] = None,
36
+ **kwargs,
37
+ ):
38
+ """
39
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
40
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
41
+
42
+ Args:
43
+ scheduler (`SchedulerMixin`):
44
+ The scheduler to get timesteps from.
45
+ num_inference_steps (`int`):
46
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
47
+ must be `None`.
48
+ device (`str` or `torch.device`, *optional*):
49
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
50
+ timesteps (`List[int]`, *optional*):
51
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
52
+ `num_inference_steps` and `sigmas` must be `None`.
53
+ sigmas (`List[float]`, *optional*):
54
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
55
+ `num_inference_steps` and `timesteps` must be `None`.
56
+
57
+ Returns:
58
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
59
+ second element is the number of inference steps.
60
+ """
61
+ if timesteps is not None and sigmas is not None:
62
+ raise ValueError(
63
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
64
+ )
65
+ if timesteps is not None:
66
+ accepts_timesteps = "timesteps" in set(
67
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
68
+ )
69
+ if not accepts_timesteps:
70
+ raise ValueError(
71
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
72
+ f" timestep schedules. Please check whether you are using the correct scheduler."
73
+ )
74
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
75
+ timesteps = scheduler.timesteps
76
+ num_inference_steps = len(timesteps)
77
+ elif sigmas is not None:
78
+ accept_sigmas = "sigmas" in set(
79
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
80
+ )
81
+ if not accept_sigmas:
82
+ raise ValueError(
83
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
84
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
85
+ )
86
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
87
+ timesteps = scheduler.timesteps
88
+ num_inference_steps = len(timesteps)
89
+ else:
90
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ return timesteps, num_inference_steps
93
+
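In the default branch (neither `timesteps` nor `sigmas` given), the helper simply defers to the scheduler. A minimal sketch with the scheduler class this pipeline registers, constructed with its default config rather than the released checkpoint's:

from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()             # default config, for illustration only
scheduler.set_timesteps(num_inference_steps=50, device="cpu")
timesteps, num_inference_steps = scheduler.timesteps, len(scheduler.timesteps)
print(num_inference_steps, timesteps[:3])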
94
+
95
+ class PartCrafterPipeline(DiffusionPipeline, TransformerDiffusionMixin):
96
+ """
97
+ Pipeline for image to 3D part-level object generation.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ vae: TripoSGVAEModel,
103
+ transformer: PartCrafterDiTModel,
104
+ scheduler: FlowMatchEulerDiscreteScheduler,
105
+ image_encoder_dinov2: Dinov2Model,
106
+ feature_extractor_dinov2: BitImageProcessor,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.register_modules(
111
+ vae=vae,
112
+ transformer=transformer,
113
+ scheduler=scheduler,
114
+ image_encoder_dinov2=image_encoder_dinov2,
115
+ feature_extractor_dinov2=feature_extractor_dinov2,
116
+ )
117
+
118
+ @property
119
+ def guidance_scale(self):
120
+ return self._guidance_scale
121
+
122
+ @property
123
+ def do_classifier_free_guidance(self):
124
+ return self._guidance_scale > 1
125
+
126
+ @property
127
+ def num_timesteps(self):
128
+ return self._num_timesteps
129
+
130
+ @property
131
+ def attention_kwargs(self):
132
+ return self._attention_kwargs
133
+
134
+ @property
135
+ def interrupt(self):
136
+ return self._interrupt
137
+
138
+ @property
139
+ def decode_progressive(self):
140
+ return self._decode_progressive
141
+
142
+ def encode_image(self, image, device, num_images_per_prompt):
143
+ dtype = next(self.image_encoder_dinov2.parameters()).dtype
144
+
145
+ if not isinstance(image, torch.Tensor):
146
+ image = self.feature_extractor_dinov2(image, return_tensors="pt").pixel_values
147
+
148
+ image = image.to(device=device, dtype=dtype)
149
+ image_embeds = self.image_encoder_dinov2(image).last_hidden_state
150
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
151
+ uncond_image_embeds = torch.zeros_like(image_embeds)
152
+
153
+ return image_embeds, uncond_image_embeds
154
+
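The image conditioning is a frozen DINOv2 encoder: the processor resizes and normalizes the image, the encoder's `last_hidden_state` becomes the cross-attention context, and an all-zeros tensor serves as the unconditional embedding. A standalone sketch; the `facebook/dinov2-base` checkpoint is an assumption for illustration, while the pipeline uses whatever encoder and processor were registered with it.

import torch
from PIL import Image
from transformers import BitImageProcessor, Dinov2Model

processor = BitImageProcessor.from_pretrained("facebook/dinov2-base")   # assumed public checkpoint
encoder = Dinov2Model.from_pretrained("facebook/dinov2-base").eval()

image = Image.new("RGB", (512, 512), color="white")       # placeholder input image
pixel_values = processor(image, return_tensors="pt").pixel_values
with torch.no_grad():
    image_embeds = encoder(pixel_values).last_hidden_state  # (1, seq_len, hidden_size)
uncond_image_embeds = torch.zeros_like(image_embeds)        # "negative" branch for CFG
print(image_embeds.shape)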
155
+ def prepare_latents(
156
+ self,
157
+ batch_size,
158
+ num_tokens,
159
+ num_channels_latents,
160
+ dtype,
161
+ device,
162
+ generator,
163
+ latents: Optional[torch.Tensor] = None,
164
+ ):
165
+ shape = (batch_size, num_tokens, num_channels_latents)
166
+
167
+ if isinstance(generator, list) and len(generator) != batch_size:
168
+ raise ValueError(
169
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
170
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
171
+ )
172
+
173
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
174
+ return noise
175
+
176
+ @torch.no_grad()
177
+ def __call__(
178
+ self,
179
+ image: PipelineImageInput,
180
+ num_inference_steps: int = 50,
181
+ num_tokens: int = 2048,
182
+ timesteps: List[int] = None,
183
+ guidance_scale: float = 7.0,
184
+ num_images_per_prompt: int = 1,
185
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
186
+ latents: Optional[torch.FloatTensor] = None,
187
+ attention_kwargs: Optional[Dict[str, Any]] = None,
188
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
189
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
190
+ bounds: Union[Tuple[float], List[float], float] = (-1.005, -1.005, -1.005, 1.005, 1.005, 1.005),
191
+ dense_octree_depth: int = 8,
192
+ hierarchical_octree_depth: int = 9,
193
+ max_num_expanded_coords: int = 1e8,
194
+ flash_octree_depth: int = 9,
195
+ use_flash_decoder: bool = True,
196
+ return_dict: bool = True,
197
+ ):
198
+ # 1. Define call parameters
199
+ self._guidance_scale = guidance_scale
200
+ self._attention_kwargs = attention_kwargs
201
+ self._interrupt = False
202
+
203
+ # 2. Define call parameters
204
+ if isinstance(image, PIL.Image.Image):
205
+ batch_size = 1
206
+ elif isinstance(image, list):
207
+ batch_size = len(image)
208
+ elif isinstance(image, torch.Tensor):
209
+ batch_size = image.shape[0]
210
+ else:
211
+ raise ValueError("Invalid input type for image")
212
+
213
+ device = self._execution_device
214
+ dtype = self.image_encoder_dinov2.dtype
215
+
216
+ # 3. Encode condition
217
+ image_embeds, negative_image_embeds = self.encode_image(
218
+ image, device, num_images_per_prompt
219
+ )
220
+
221
+ if self.do_classifier_free_guidance:
222
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
223
+
224
+ # 4. Prepare timesteps
225
+ timesteps, num_inference_steps = retrieve_timesteps(
226
+ self.scheduler, num_inference_steps, device, timesteps
227
+ )
228
+ num_warmup_steps = max(
229
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
230
+ )
231
+ self._num_timesteps = len(timesteps)
232
+
233
+ # 5. Prepare latent variables
234
+ num_channels_latents = self.transformer.config.in_channels
235
+ latents = self.prepare_latents(
236
+ batch_size * num_images_per_prompt,
237
+ num_tokens,
238
+ num_channels_latents,
239
+ image_embeds.dtype,
240
+ device,
241
+ generator,
242
+ latents,
243
+ )
244
+
245
+ # 6. Denoising loop
246
+ self.set_progress_bar_config(
247
+ desc="Denoising",
248
+ ncols=125,
249
+ disable=self._progress_bar_config['disable'] if hasattr(self, '_progress_bar_config') else False,
250
+ )
251
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
252
+ for i, t in enumerate(timesteps):
253
+ if self.interrupt:
254
+ continue
255
+
256
+ # expand the latents if we are doing classifier free guidance
257
+ latent_model_input = (
258
+ torch.cat([latents] * 2)
259
+ if self.do_classifier_free_guidance
260
+ else latents
261
+ )
262
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
263
+ timestep = t.expand(latent_model_input.shape[0])
264
+
265
+ noise_pred = self.transformer(
266
+ latent_model_input,
267
+ timestep,
268
+ encoder_hidden_states=image_embeds,
269
+ attention_kwargs=attention_kwargs,
270
+ return_dict=False,
271
+ )[0].to(dtype)
272
+
273
+ # perform guidance
274
+ if self.do_classifier_free_guidance:
275
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
276
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
277
+ noise_pred_image - noise_pred_uncond
278
+ )
279
+
280
+ # compute the previous noisy sample x_t -> x_t-1
281
+ latents_dtype = latents.dtype
282
+ latents = self.scheduler.step(
283
+ noise_pred, t, latents, return_dict=False
284
+ )[0]
285
+
286
+ if latents.dtype != latents_dtype:
287
+ if torch.backends.mps.is_available():
288
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
289
+ latents = latents.to(latents_dtype)
290
+
291
+ if callback_on_step_end is not None:
292
+ callback_kwargs = {}
293
+ for k in callback_on_step_end_tensor_inputs:
294
+ callback_kwargs[k] = locals()[k]
295
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
296
+
297
+ latents = callback_outputs.pop("latents", latents)
298
+ image_embeds = callback_outputs.pop("image_embeds", image_embeds)
299
+ negative_image_embeds = callback_outputs.pop(
300
+ "negative_image_embeds", negative_image_embeds
301
+ )
310
+
311
+ # call the callback, if provided
312
+ if i == len(timesteps) - 1 or (
313
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
314
+ ):
315
+ progress_bar.update()
316
+
317
+
318
+ # 7. Decode meshes
319
+ self.vae.set_flash_decoder()
320
+ output, meshes = [], []
321
+ self.set_progress_bar_config(
322
+ desc="Decoding",
323
+ ncols=125,
324
+ disable=self._progress_bar_config['disable'] if hasattr(self, '_progress_bar_config') else False,
325
+ )
326
+ with self.progress_bar(total=batch_size) as progress_bar:
327
+ for i in range(batch_size):
328
+ geometric_func = lambda x: self.vae.decode(latents[i].unsqueeze(0), sampled_points=x).sample
329
+ try:
330
+ mesh_v_f = hierarchical_extract_geometry(
331
+ geometric_func,
332
+ device,
333
+ dtype=latents.dtype,
334
+ bounds=bounds,
335
+ dense_octree_depth=dense_octree_depth,
336
+ hierarchical_octree_depth=hierarchical_octree_depth,
337
+ max_num_expanded_coords=max_num_expanded_coords,
338
+ # verbose=True
339
+ )
340
+ mesh = trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1])
341
+ except Exception:
342
+ mesh_v_f = None
343
+ mesh = None
344
+ output.append(mesh_v_f)
345
+ meshes.append(mesh)
346
+ progress_bar.update()
347
+
348
+ # Offload all models
349
+ self.maybe_free_model_hooks()
350
+
351
+ if not return_dict:
352
+ return (output, meshes)
353
+
354
+ return PartCrafterPipelineOutput(samples=output, meshes=meshes)
355
+
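A hypothetical end-to-end invocation, for orientation only: the import path mirrors the file location in this commit, the checkpoint path and part count are placeholders, and replicating the conditioning image once per part is an assumption about how the batch maps to parts; argument names follow the `__call__` signature above.

import torch
from PIL import Image
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline  # hypothetical import path

pipe = PartCrafterPipeline.from_pretrained("path/to/PartCrafter", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = Image.open("object.png")                          # placeholder input image
num_parts = 4                                             # placeholder part count
out = pipe(
    image=[image] * num_parts,                            # assumed: one conditioning copy per part
    num_inference_steps=50,
    guidance_scale=7.0,
    attention_kwargs={"num_parts": num_parts},            # consumed by part embedding / global attention
)
meshes = out.meshes                                       # list of trimesh.Trimesh (None on failure)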
src/pipelines/pipeline_partcrafter_output.py ADDED
@@ -0,0 +1,17 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import trimesh
7
+ from diffusers.utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class PartCrafterPipelineOutput(BaseOutput):
12
+ r"""
13
+ Output class for PartCrafter pipelines.
14
+ """
15
+
16
+ samples: torch.Tensor
17
+ meshes: List[trimesh.Trimesh]
src/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,96 @@
1
+ from diffusers.utils import logging
2
+
3
+ logger = logging.get_logger(__name__)
4
+
5
+
6
+ class TransformerDiffusionMixin:
7
+ r"""
8
+ Helper for DiffusionPipeline with a VAE and a transformer (mainly for DiT).
9
+ """
10
+
11
+ def enable_vae_slicing(self):
12
+ r"""
13
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
14
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
15
+ """
16
+ self.vae.enable_slicing()
17
+
18
+ def disable_vae_slicing(self):
19
+ r"""
20
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
21
+ computing decoding in one step.
22
+ """
23
+ self.vae.disable_slicing()
24
+
25
+ def enable_vae_tiling(self):
26
+ r"""
27
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
28
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
29
+ processing larger images.
30
+ """
31
+ self.vae.enable_tiling()
32
+
33
+ def disable_vae_tiling(self):
34
+ r"""
35
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
36
+ computing decoding in one step.
37
+ """
38
+ self.vae.disable_tiling()
39
+
40
+ def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
41
+ """
42
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
43
+ are fused. For cross-attention modules, key and value projection matrices are fused.
44
+
45
+ <Tip warning={true}>
46
+
47
+ This API is 🧪 experimental.
48
+
49
+ </Tip>
50
+
51
+ Args:
52
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
53
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
54
+ """
55
+ self.fusing_transformer = False
56
+ self.fusing_vae = False
57
+
58
+ if transformer:
59
+ self.fusing_transformer = True
60
+ self.transformer.fuse_qkv_projections()
61
+
62
+ if vae:
63
+ self.fusing_vae = True
64
+ self.vae.fuse_qkv_projections()
65
+
66
+ def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
67
+ """Disable QKV projection fusion if enabled.
68
+
69
+ <Tip warning={true}>
70
+
71
+ This API is 🧪 experimental.
72
+
73
+ </Tip>
74
+
75
+ Args:
76
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
77
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
78
+
79
+ """
80
+ if transformer:
81
+ if not self.fusing_transformer:
82
+ logger.warning(
83
+ "The UNet was not initially fused for QKV projections. Doing nothing."
84
+ )
85
+ else:
86
+ self.transformer.unfuse_qkv_projections()
87
+ self.fusing_transformer = False
88
+
89
+ if vae:
90
+ if not self.fusing_vae:
91
+ logger.warning(
92
+ "The VAE was not initially fused for QKV projections. Doing nothing."
93
+ )
94
+ else:
95
+ self.vae.unfuse_qkv_projections()
96
+ self.fusing_vae = False
src/schedulers/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .scheduling_rectified_flow import (
2
+ RectifiedFlowScheduler,
3
+ compute_density_for_timestep_sampling,
4
+ compute_loss_weighting,
5
+ )
src/schedulers/scheduling_rectified_flow.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
3
+ """
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
13
+ from diffusers.utils import BaseOutput, logging
14
+ from torch.distributions import LogisticNormal
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
+ # TODO: may move to training_utils.py
20
+ def compute_density_for_timestep_sampling(
21
+ weighting_scheme: str,
22
+ batch_size: int,
23
+ logit_mean: float = 0.0,
24
+ logit_std: float = 1.0,
25
+ mode_scale: float = None,
26
+ ):
27
+ if weighting_scheme == "logit_normal":
28
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
29
+ u = torch.normal(
30
+ mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
31
+ )
32
+ u = torch.nn.functional.sigmoid(u)
33
+ elif weighting_scheme == "logit_normal_dist":
34
+ u = (
35
+ LogisticNormal(loc=logit_mean, scale=logit_std)
36
+ .sample((batch_size,))[:, 0]
37
+ .to("cpu")
38
+ )
39
+ elif weighting_scheme == "mode":
40
+ u = torch.rand(size=(batch_size,), device="cpu")
41
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
42
+ else:
43
+ u = torch.rand(size=(batch_size,), device="cpu")
44
+ return u
45
+
46
+
47
+ def compute_loss_weighting(weighting_scheme: str, sigmas=None):
48
+ """
49
+ Computes loss weighting scheme for SD3 training.
50
+
51
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
52
+
53
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
54
+ """
55
+ if weighting_scheme == "sigma_sqrt":
56
+ weighting = (sigmas**-2.0).float()
57
+ elif weighting_scheme == "cosmap":
58
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
59
+ weighting = 2 / (math.pi * bot)
60
+ else:
61
+ weighting = torch.ones_like(sigmas)
62
+ return weighting
63
+
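A short sketch of the `logit_normal` branch above: a standard normal draw is squashed through a sigmoid into (0, 1), and the resulting values are scaled to continuous training timesteps; the scaling to 1000 steps is an assumption mirroring a typical `num_train_timesteps`.

import torch

batch_size, num_train_timesteps = 4, 1000
u = torch.normal(mean=0.0, std=1.0, size=(batch_size,))   # logit-normal: normal draw ...
u = torch.sigmoid(u)                                      # ... squashed into (0, 1)
timesteps = u * num_train_timesteps                       # continuous timesteps for training
sigmas = timesteps / num_train_timesteps                  # noise level at each sampled step
print(timesteps, sigmas)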
64
+
65
+ @dataclass
66
+ class RectifiedFlowSchedulerOutput(BaseOutput):
67
+ """
68
+ Output class for the scheduler's `step` function output.
69
+
70
+ Args:
71
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
72
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
73
+ denoising loop.
74
+ """
75
+
76
+ prev_sample: torch.FloatTensor
77
+
78
+
79
+ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
80
+ """
81
+ A scheduler that propagates the diffusion process with a rectified-flow (flow-matching) update rule.
82
+
83
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
84
+ methods the library implements for all schedulers such as loading and saving.
85
+
86
+ Args:
87
+ num_train_timesteps (`int`, defaults to 1000):
88
+ The number of diffusion steps to train the model.
89
+ timestep_spacing (`str`, defaults to `"linspace"`):
90
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
91
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
92
+ shift (`float`, defaults to 1.0):
93
+ The shift value for the timestep schedule.
94
+ """
95
+
96
+ _compatibles = []
97
+ order = 1
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ num_train_timesteps: int = 1000,
103
+ shift: float = 1.0,
104
+ use_dynamic_shifting: bool = False,
105
+ ):
106
+ # Pre-compute timesteps and sigmas; they are not actually used for training,
107
+ # since shape diffusion samples timesteps randomly or from a distribution
108
+ # instead of drawing them from this pre-defined linspace.
109
+ timesteps = np.array(
110
+ [
111
+ (1.0 - i / num_train_timesteps) * num_train_timesteps
112
+ for i in range(num_train_timesteps)
113
+ ]
114
+ )
115
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
116
+
117
+ sigmas = timesteps / num_train_timesteps
118
+ if not use_dynamic_shifting:
119
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
120
+ sigmas = self.time_shift(sigmas)
121
+
122
+ self.timesteps = sigmas * num_train_timesteps
123
+
124
+ self._step_index = None
125
+ self._begin_index = None
126
+
127
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
128
+
129
+ @property
130
+ def step_index(self):
131
+ """
132
+ The index counter for current timestep. It will increase 1 after each scheduler step.
133
+ """
134
+ return self._step_index
135
+
136
+ @property
137
+ def begin_index(self):
138
+ """
139
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
140
+ """
141
+ return self._begin_index
142
+
143
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
144
+ def set_begin_index(self, begin_index: int = 0):
145
+ """
146
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
147
+
148
+ Args:
149
+ begin_index (`int`):
150
+ The begin index for the scheduler.
151
+ """
152
+ self._begin_index = begin_index
153
+
154
+ def _sigma_to_t(self, sigma):
155
+ return sigma * self.config.num_train_timesteps
156
+
157
+ def _t_to_sigma(self, timestep):
158
+ return timestep / self.config.num_train_timesteps
159
+
160
+ def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
161
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
162
+
163
+ def time_shift(self, t: torch.Tensor):
164
+ return self.config.shift * t / (1 + (self.config.shift - 1) * t)
165
+
166
+ def set_timesteps(
167
+ self,
168
+ num_inference_steps: int = None,
169
+ device: Union[str, torch.device] = None,
170
+ sigmas: Optional[List[float]] = None,
171
+ mu: Optional[float] = None,
172
+ ):
173
+ """
174
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
175
+
176
+ Args:
177
+ num_inference_steps (`int`):
178
+ The number of diffusion steps used when generating samples with a pre-trained model.
179
+ device (`str` or `torch.device`, *optional*):
180
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
181
+ """
182
+
183
+ if self.config.use_dynamic_shifting and mu is None:
184
+ raise ValueError(
185
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
186
+ )
187
+
188
+ if sigmas is None:
189
+ self.num_inference_steps = num_inference_steps
190
+ timesteps = np.array(
191
+ [
192
+ (1.0 - i / num_inference_steps) * self.config.num_train_timesteps
193
+ for i in range(num_inference_steps)
194
+ ]
195
+ ) # different from the original code in SD3
196
+ sigmas = timesteps / self.config.num_train_timesteps
197
+
198
+ if self.config.use_dynamic_shifting:
199
+ sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
200
+ else:
201
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
202
+
203
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
204
+ timesteps = sigmas * self.config.num_train_timesteps
205
+
206
+ self.timesteps = timesteps.to(device=device)
207
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
208
+
209
+ self._step_index = None
210
+ self._begin_index = None
211
+
212
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
213
+ if schedule_timesteps is None:
214
+ schedule_timesteps = self.timesteps
215
+
216
+ indices = (schedule_timesteps == timestep).nonzero()
217
+
218
+ # The sigma index that is taken for the **very** first `step`
219
+ # is always the second index (or the last index if there is only 1)
220
+ # This way we can ensure we don't accidentally skip a sigma in
221
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
222
+ pos = 1 if len(indices) > 1 else 0
223
+
224
+ return indices[pos].item()
225
+
226
+ def _init_step_index(self, timestep):
227
+ if self.begin_index is None:
228
+ if isinstance(timestep, torch.Tensor):
229
+ timestep = timestep.to(self.timesteps.device)
230
+ self._step_index = self.index_for_timestep(timestep)
231
+ else:
232
+ self._step_index = self._begin_index
233
+
234
+ def step(
235
+ self,
236
+ model_output: torch.FloatTensor,
237
+ timestep: Union[float, torch.FloatTensor],
238
+ sample: torch.FloatTensor,
239
+ s_churn: float = 0.0,
240
+ s_tmin: float = 0.0,
241
+ s_tmax: float = float("inf"),
242
+ s_noise: float = 1.0,
243
+ generator: Optional[torch.Generator] = None,
244
+ return_dict: bool = True,
245
+ ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
246
+ """
247
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
248
+ process from the learned model outputs (most often the predicted noise).
249
+
250
+ Args:
251
+ model_output (`torch.FloatTensor`):
252
+ The direct output from learned diffusion model.
253
+ timestep (`float`):
254
+ The current discrete timestep in the diffusion chain.
255
+ sample (`torch.FloatTensor`):
256
+ A current instance of a sample created by the diffusion process.
257
+ s_churn (`float`):
258
+ s_tmin (`float`):
259
+ s_tmax (`float`):
260
+ s_noise (`float`, defaults to 1.0):
261
+ Scaling factor for noise added to the sample.
262
+ generator (`torch.Generator`, *optional*):
263
+ A random number generator.
264
+ return_dict (`bool`):
265
+ Whether or not to return a [`~schedulers.scheduling_rectified_flow.RectifiedFlowSchedulerOutput`] or
266
+ tuple.
267
+
268
+ Returns:
269
+ [`RectifiedFlowSchedulerOutput`] or `tuple`:
270
+ If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is
271
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
272
+ """
273
+
274
+ if (
275
+ isinstance(timestep, int)
276
+ or isinstance(timestep, torch.IntTensor)
277
+ or isinstance(timestep, torch.LongTensor)
278
+ ):
279
+ raise ValueError(
280
+ (
281
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
282
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
283
+ " one of the `scheduler.timesteps` as a timestep."
284
+ ),
285
+ )
286
+
287
+ if self.step_index is None:
288
+ self._init_step_index(timestep)
289
+
290
+ # Upcast to avoid precision issues when computing prev_sample
291
+ sample = sample.to(torch.float32)
292
+
293
+ sigma = self.sigmas[self.step_index]
294
+ sigma_next = self.sigmas[self.step_index + 1]
295
+
296
+ # Note: the update direction is reversed relative to standard flow matching, hence the (sigma - sigma_next) factor below
297
+ prev_sample = sample + (sigma - sigma_next) * model_output
298
+
299
+ # Cast sample back to model compatible dtype
300
+ prev_sample = prev_sample.to(model_output.dtype)
301
+
302
+ # upon completion increase step index by one
303
+ self._step_index += 1
304
+
305
+ if not return_dict:
306
+ return (prev_sample,)
307
+
308
+ return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
309
+
310
+ def scale_noise(
311
+ self,
312
+ original_samples: torch.Tensor,
313
+ noise: torch.Tensor,
314
+ timesteps: torch.IntTensor,
315
+ ) -> torch.Tensor:
316
+ """
317
+ Forward (noising) function for flow matching: interpolates the clean samples toward noise.
318
+ """
319
+ sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
320
+
321
+ while len(sigmas.shape) < len(original_samples.shape):
322
+ sigmas = sigmas.unsqueeze(-1)
323
+
324
+ return (1.0 - sigmas) * original_samples + sigmas * noise
325
+
326
+ def __len__(self):
327
+ return self.config.num_train_timesteps
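As a quick sanity check of the sigma schedule and the Euler update above, here is a minimal self-contained sketch. The `num_train_timesteps`, `shift`, and step count are illustrative assumptions, not necessarily the values shipped with the pretrained scheduler config.

import numpy as np

# Rebuild the sigma schedule from set_timesteps() for a static shift.
num_train_timesteps = 1000   # assumed for illustration
shift = 3.0                  # assumed for illustration
num_inference_steps = 4

timesteps = np.array([
    (1.0 - i / num_inference_steps) * num_train_timesteps
    for i in range(num_inference_steps)
])
sigmas = timesteps / num_train_timesteps              # [1.0, 0.75, 0.5, 0.25]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # static (non-dynamic) shifting
sigmas = np.append(sigmas, 0.0)                       # terminal sigma, matching self.sigmas

# One Euler step with the reversed-direction update used in step():
# prev_sample = sample + (sigma - sigma_next) * model_output
sample = np.random.randn(4).astype(np.float32)
model_output = np.zeros(4, dtype=np.float32)  # dummy velocity prediction
prev_sample = sample + (sigmas[0] - sigmas[1]) * model_output
print(sigmas, prev_sample.shape)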
src/train_partcrafter.py ADDED
@@ -0,0 +1,1007 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore") # ignore all warnings
3
+ import diffusers.utils.logging as diffusion_logging
4
+ diffusion_logging.set_verbosity_error() # ignore diffusers warnings
5
+
6
+ from src.utils.typing_utils import *
7
+
8
+ import os
9
+ import argparse
10
+ import logging
11
+ import time
12
+ import math
13
+ import gc
14
+ from packaging import version
15
+
16
+ import trimesh
17
+ from PIL import Image
18
+ import numpy as np
19
+ import wandb
20
+ from tqdm import tqdm
21
+
22
+ import torch
23
+ import torch.nn.functional as tF
24
+ import accelerate
25
+ from accelerate import Accelerator
26
+ from accelerate.logging import get_logger as get_accelerate_logger
27
+ from accelerate import DataLoaderConfiguration, DeepSpeedPlugin
28
+ from diffusers.training_utils import (
29
+ compute_density_for_timestep_sampling,
30
+ compute_loss_weighting_for_sd3
31
+ )
32
+
33
+ from transformers import (
34
+ BitImageProcessor,
35
+ Dinov2Model,
36
+ )
37
+ from src.schedulers import RectifiedFlowScheduler
38
+ from src.models.autoencoders import TripoSGVAEModel
39
+ from src.models.transformers import PartCrafterDiTModel
40
+ from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
41
+
42
+ from src.datasets import (
43
+ ObjaversePartDataset,
44
+ BatchedObjaversePartDataset,
45
+ MultiEpochsDataLoader,
46
+ yield_forever
47
+ )
48
+ from src.utils.data_utils import get_colored_mesh_composition
49
+ from src.utils.train_utils import (
50
+ MyEMAModel,
51
+ get_configs,
52
+ get_optimizer,
53
+ get_lr_scheduler,
54
+ save_experiment_params,
55
+ save_model_architecture,
56
+ )
57
+ from src.utils.render_utils import (
58
+ render_views_around_mesh,
59
+ render_normal_views_around_mesh,
60
+ make_grid_for_images_or_videos,
61
+ export_renderings
62
+ )
63
+ from src.utils.metric_utils import compute_cd_and_f_score_in_training
64
+
65
+ def main():
66
+ PROJECT_NAME = "PartCrafter"
67
+
68
+ parser = argparse.ArgumentParser(
69
+ description="Train a diffusion model for 3D object generation",
70
+ )
71
+
72
+ parser.add_argument(
73
+ "--config",
74
+ type=str,
75
+ required=True,
76
+ help="Path to the config file"
77
+ )
78
+ parser.add_argument(
79
+ "--tag",
80
+ type=str,
81
+ default=None,
82
+ help="Tag that refers to the current experiment"
83
+ )
84
+ parser.add_argument(
85
+ "--output_dir",
86
+ type=str,
87
+ default="output",
88
+ help="Path to the output directory"
89
+ )
90
+ parser.add_argument(
91
+ "--resume_from_iter",
92
+ type=int,
93
+ default=None,
94
+ help="The iteration to load the checkpoint from"
95
+ )
96
+ parser.add_argument(
97
+ "--seed",
98
+ type=int,
99
+ default=0,
100
+ help="Seed for the PRNG"
101
+ )
102
+ parser.add_argument(
103
+ "--offline_wandb",
104
+ action="store_true",
105
+ help="Use offline WandB for experiment tracking"
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--max_train_steps",
110
+ type=int,
111
+ default=None,
112
+ help="The max iteration step for training"
113
+ )
114
+ parser.add_argument(
115
+ "--max_val_steps",
116
+ type=int,
117
+ default=2,
118
+ help="The max iteration step for validation"
119
+ )
120
+ parser.add_argument(
121
+ "--num_workers",
122
+ type=int,
123
+ default=32,
124
+ help="The number of processed spawned by the batch provider"
125
+ )
126
+ parser.add_argument(
127
+ "--pin_memory",
128
+ action="store_true",
129
+ help="Pin memory for the data loader"
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--use_ema",
134
+ action="store_true",
135
+ help="Use EMA model for training"
136
+ )
137
+ parser.add_argument(
138
+ "--scale_lr",
139
+ action="store_true",
140
+ help="Scale lr with total batch size (base batch size: 256)"
141
+ )
142
+ parser.add_argument(
143
+ "--max_grad_norm",
144
+ type=float,
145
+ default=1.,
146
+ help="Max gradient norm for gradient clipping"
147
+ )
148
+ parser.add_argument(
149
+ "--gradient_accumulation_steps",
150
+ type=int,
151
+ default=1,
152
+ help="Number of updates steps to accumulate before performing a backward/update pass"
153
+ )
154
+ parser.add_argument(
155
+ "--mixed_precision",
156
+ type=str,
157
+ default="fp16",
158
+ choices=["no", "fp16", "bf16"],
159
+ help="Type of mixed precision training"
160
+ )
161
+ parser.add_argument(
162
+ "--allow_tf32",
163
+ action="store_true",
164
+ help="Enable TF32 for faster training on Ampere GPUs"
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--val_guidance_scales",
169
+ type=float,
170
+ nargs="+",
171
+ default=[7.0],
172
+ help="CFG scale used for validation"
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--use_deepspeed",
177
+ action="store_true",
178
+ help="Use DeepSpeed for training"
179
+ )
180
+ parser.add_argument(
181
+ "--zero_stage",
182
+ type=int,
183
+ default=1,
184
+ choices=[1, 2, 3], # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
185
+ help="ZeRO stage type for DeepSpeed"
186
+ )
187
+
188
+ parser.add_argument(
189
+ "--from_scratch",
190
+ action="store_true",
191
+ help="Train from scratch"
192
+ )
193
+ parser.add_argument(
194
+ "--load_pretrained_model",
195
+ type=str,
196
+ default=None,
197
+ help="Tag of a pretrained PartCrafterDiTModel in this project"
198
+ )
199
+ parser.add_argument(
200
+ "--load_pretrained_model_ckpt",
201
+ type=int,
202
+ default=-1,
203
+ help="Iteration of the pretrained PartCrafterDiTModel checkpoint"
204
+ )
205
+
206
+ # Parse the arguments
207
+ args, extras = parser.parse_known_args()
208
+ # Parse the config file
209
+ configs = get_configs(args.config, extras) # change yaml configs by `extras`
210
+
211
+ args.val_guidance_scales = [float(x) for x in args.val_guidance_scales]
212
+ if args.max_val_steps > 0:
213
+ # If validation is enabled, max_val_steps must be a multiple of nrow
214
+ # Always keep the validation batch size at 1
215
+ divider = configs["val"]["nrow"]
216
+ args.max_val_steps = max(args.max_val_steps, divider)
217
+ if args.max_val_steps % divider != 0:
218
+ args.max_val_steps = (args.max_val_steps // divider + 1) * divider
219
+
220
+ # Create an experiment directory using the `tag`
221
+ if args.tag is None:
222
+ args.tag = time.strftime("%Y%m%d_%H_%M_%S")
223
+ exp_dir = os.path.join(args.output_dir, args.tag)
224
+ ckpt_dir = os.path.join(exp_dir, "checkpoints")
225
+ eval_dir = os.path.join(exp_dir, "evaluations")
226
+ os.makedirs(ckpt_dir, exist_ok=True)
227
+ os.makedirs(eval_dir, exist_ok=True)
228
+
229
+ # Initialize the logger
230
+ logging.basicConfig(
231
+ format="%(asctime)s - %(message)s",
232
+ datefmt="%Y/%m/%d %H:%M:%S",
233
+ level=logging.INFO
234
+ )
235
+ logger = get_accelerate_logger(__name__, log_level="INFO")
236
+ file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt")) # output to file
237
+ file_handler.setFormatter(logging.Formatter(
238
+ fmt="%(asctime)s - %(message)s",
239
+ datefmt="%Y/%m/%d %H:%M:%S"
240
+ ))
241
+ logger.logger.addHandler(file_handler)
242
+ logger.logger.propagate = True # propagate to the root logger (console)
243
+
244
+ # Set DeepSpeed config
245
+ if args.use_deepspeed:
246
+ deepspeed_plugin = DeepSpeedPlugin(
247
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
248
+ gradient_clipping=args.max_grad_norm,
249
+ zero_stage=int(args.zero_stage),
250
+ offload_optimizer_device="cpu", # hard-coded here, TODO: make it configurable
251
+ )
252
+ else:
253
+ deepspeed_plugin = None
254
+
255
+ # Initialize the accelerator
256
+ accelerator = Accelerator(
257
+ project_dir=exp_dir,
258
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
259
+ mixed_precision=args.mixed_precision,
260
+ split_batches=False, # batch size per GPU
261
+ dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory),
262
+ deepspeed_plugin=deepspeed_plugin,
263
+ )
264
+ logger.info(f"Accelerator state:\n{accelerator.state}\n")
265
+
266
+ # Set the random seed
267
+ if args.seed >= 0:
268
+ accelerate.utils.set_seed(args.seed)
269
+ logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n")
270
+
271
+ # Enable TF32 for faster training on Ampere GPUs
272
+ if args.allow_tf32:
273
+ torch.backends.cuda.matmul.allow_tf32 = True
274
+
275
+ train_dataset = BatchedObjaversePartDataset(
276
+ configs=configs,
277
+ batch_size=configs["train"]["batch_size_per_gpu"],
278
+ is_main_process=accelerator.is_main_process,
279
+ shuffle=True,
280
+ training=True,
281
+ )
282
+ val_dataset = ObjaversePartDataset(
283
+ configs=configs,
284
+ training=False,
285
+ )
286
+ train_loader = MultiEpochsDataLoader(
287
+ train_dataset,
288
+ batch_size=configs["train"]["batch_size_per_gpu"],
289
+ num_workers=args.num_workers,
290
+ drop_last=True,
291
+ pin_memory=args.pin_memory,
292
+ collate_fn=train_dataset.collate_fn,
293
+ )
294
+ val_loader = MultiEpochsDataLoader(
295
+ val_dataset,
296
+ batch_size=configs["val"]["batch_size_per_gpu"],
297
+ num_workers=args.num_workers,
298
+ drop_last=True,
299
+ pin_memory=args.pin_memory,
300
+ )
301
+ random_val_loader = MultiEpochsDataLoader(
302
+ val_dataset,
303
+ batch_size=configs["val"]["batch_size_per_gpu"],
304
+ shuffle=True,
305
+ num_workers=args.num_workers,
306
+ drop_last=True,
307
+ pin_memory=args.pin_memory,
308
+ )
309
+
310
+ logger.info(f"Loaded [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n")
311
+
312
+ # Compute the effective batch size and scale learning rate
313
+ total_batch_size = configs["train"]["batch_size_per_gpu"] * \
314
+ accelerator.num_processes * args.gradient_accumulation_steps
315
+ configs["train"]["total_batch_size"] = total_batch_size
316
+ if args.scale_lr:
317
+ configs["optimizer"]["lr"] *= (total_batch_size / 256)
318
+ configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"]
319
+
320
+ # Initialize the model
321
+ logger.info("Initializing the model...")
322
+ vae = TripoSGVAEModel.from_pretrained(
323
+ configs["model"]["pretrained_model_name_or_path"],
324
+ subfolder="vae"
325
+ )
326
+ feature_extractor_dinov2 = BitImageProcessor.from_pretrained(
327
+ configs["model"]["pretrained_model_name_or_path"],
328
+ subfolder="feature_extractor_dinov2"
329
+ )
330
+ image_encoder_dinov2 = Dinov2Model.from_pretrained(
331
+ configs["model"]["pretrained_model_name_or_path"],
332
+ subfolder="image_encoder_dinov2"
333
+ )
334
+
335
+ enable_part_embedding = configs["model"]["transformer"].get("enable_part_embedding", True)
336
+ enable_local_cross_attn = configs["model"]["transformer"].get("enable_local_cross_attn", True)
337
+ enable_global_cross_attn = configs["model"]["transformer"].get("enable_global_cross_attn", True)
338
+ global_attn_block_ids = configs["model"]["transformer"].get("global_attn_block_ids", None)
339
+ if global_attn_block_ids is not None:
340
+ global_attn_block_ids = list(global_attn_block_ids)
341
+ global_attn_block_id_range = configs["model"]["transformer"].get("global_attn_block_id_range", None)
342
+ if global_attn_block_id_range is not None:
343
+ global_attn_block_id_range = list(global_attn_block_id_range)
344
+ if args.from_scratch:
345
+ logger.info(f"Initialize PartCrafterDiTModel from scratch\n")
346
+ transformer = PartCrafterDiTModel.from_config(
347
+ os.path.join(
348
+ configs["model"]["pretrained_model_name_or_path"],
349
+ "transformer"
350
+ ),
351
+ enable_part_embedding=enable_part_embedding,
352
+ enable_local_cross_attn=enable_local_cross_attn,
353
+ enable_global_cross_attn=enable_global_cross_attn,
354
+ global_attn_block_ids=global_attn_block_ids,
355
+ global_attn_block_id_range=global_attn_block_id_range,
356
+ )
357
+ elif args.load_pretrained_model is None:
358
+ logger.info(f"Load pretrained TripoSGDiTModel to initialize PartCrafterDiTModel from [{configs['model']['pretrained_model_name_or_path']}]\n")
359
+ transformer, loading_info = PartCrafterDiTModel.from_pretrained(
360
+ configs["model"]["pretrained_model_name_or_path"],
361
+ subfolder="transformer",
362
+ low_cpu_mem_usage=False,
363
+ output_loading_info=True,
364
+ enable_part_embedding=enable_part_embedding,
365
+ enable_local_cross_attn=enable_local_cross_attn,
366
+ enable_global_cross_attn=enable_global_cross_attn,
367
+ global_attn_block_ids=global_attn_block_ids,
368
+ global_attn_block_id_range=global_attn_block_id_range,
369
+ )
370
+ else:
371
+ logger.info(f"Load PartCrafterDiTModel EMA checkpoint from [{args.load_pretrained_model}] iteration [{args.load_pretrained_model_ckpt:06d}]\n")
372
+ path = os.path.join(
373
+ args.output_dir,
374
+ args.load_pretrained_model,
375
+ "checkpoints",
376
+ f"{args.load_pretrained_model_ckpt:06d}"
377
+ )
378
+ transformer, loading_info = PartCrafterDiTModel.from_pretrained(
379
+ path,
380
+ subfolder="transformer_ema",
381
+ low_cpu_mem_usage=False,
382
+ output_loading_info=True,
383
+ enable_part_embedding=enable_part_embedding,
384
+ enable_local_cross_attn=enable_local_cross_attn,
385
+ enable_global_cross_attn=enable_global_cross_attn,
386
+ global_attn_block_ids=global_attn_block_ids,
387
+ global_attn_block_id_range=global_attn_block_id_range,
388
+ )
389
+ if not args.from_scratch:
390
+ for v in loading_info.values():
391
+ if v and len(v) > 0:
392
+ logger.info(f"Loading info of PartCrafterDiTModel: {loading_info}\n")
393
+ break
394
+
395
+ noise_scheduler = RectifiedFlowScheduler.from_pretrained(
396
+ configs["model"]["pretrained_model_name_or_path"],
397
+ subfolder="scheduler"
398
+ )
399
+
400
+ if args.use_ema:
401
+ ema_transformer = MyEMAModel(
402
+ transformer.parameters(),
403
+ model_cls=PartCrafterDiTModel,
404
+ model_config=transformer.config,
405
+ **configs["train"]["ema_kwargs"]
406
+ )
407
+
408
+ # Freeze VAE and image encoder
409
+ vae.requires_grad_(False)
410
+ image_encoder_dinov2.requires_grad_(False)
411
+ vae.eval()
412
+ image_encoder_dinov2.eval()
413
+
414
+ trainable_modules = configs["train"].get("trainable_modules", None)
415
+ if trainable_modules is None:
416
+ transformer.requires_grad_(True)
417
+ else:
418
+ trainable_module_names = []
419
+ transformer.requires_grad_(False)
420
+ for name, module in transformer.named_modules():
421
+ for module_name in tuple(trainable_modules.split(",")):
422
+ if module_name in name:
423
+ for params in module.parameters():
424
+ params.requires_grad = True
425
+ trainable_module_names.append(name)
426
+ logger.info(f"Trainable parameter names: {trainable_module_names}\n")
427
+
428
+ # transformer.enable_xformers_memory_efficient_attention() # use `tF.scaled_dot_product_attention` instead
429
+
430
+ # `accelerate` 0.16.0 will have better support for customized saving
431
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
432
+ # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
433
+ def save_model_hook(models, weights, output_dir):
434
+ if accelerator.is_main_process:
435
+ if args.use_ema:
436
+ ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema"))
437
+
438
+ for i, model in enumerate(models):
439
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
440
+
441
+ # Make sure to pop weight so that corresponding model is not saved again
442
+ if weights:
443
+ weights.pop()
444
+
445
+ def load_model_hook(models, input_dir):
446
+ if args.use_ema:
447
+ load_model = MyEMAModel.from_pretrained(os.path.join(input_dir, "transformer_ema"), PartCrafterDiTModel)
448
+ ema_transformer.load_state_dict(load_model.state_dict())
449
+ ema_transformer.to(accelerator.device)
450
+ del load_model
451
+
452
+ for _ in range(len(models)):
453
+ # Pop models so that they are not loaded again
454
+ model = models.pop()
455
+
456
+ # Load diffusers style into model
457
+ load_model = PartCrafterDiTModel.from_pretrained(input_dir, subfolder="transformer")
458
+ model.register_to_config(**load_model.config)
459
+
460
+ model.load_state_dict(load_model.state_dict())
461
+ del load_model
462
+
463
+ accelerator.register_save_state_pre_hook(save_model_hook)
464
+ accelerator.register_load_state_pre_hook(load_model_hook)
465
+
466
+ if configs["train"]["grad_checkpoint"]:
467
+ transformer.enable_gradient_checkpointing()
468
+
469
+ # Initialize the optimizer and learning rate scheduler
470
+ logger.info("Initializing the optimizer and learning rate scheduler...\n")
471
+ name_lr_mult = configs["train"].get("name_lr_mult", None)
472
+ lr_mult = configs["train"].get("lr_mult", 1.0)
473
+ params, params_lr_mult, names_lr_mult = [], [], []
474
+ for name, param in transformer.named_parameters():
475
+ if name_lr_mult is not None:
476
+ for k in name_lr_mult.split(","):
477
+ if k in name:
478
+ params_lr_mult.append(param)
479
+ names_lr_mult.append(name)
480
+ if name not in names_lr_mult:
481
+ params.append(param)
482
+ else:
483
+ params.append(param)
484
+ optimizer = get_optimizer(
485
+ params=[
486
+ {"params": params, "lr": configs["optimizer"]["lr"]},
487
+ {"params": params_lr_mult, "lr": configs["optimizer"]["lr"] * lr_mult}
488
+ ],
489
+ **configs["optimizer"]
490
+ )
491
+ if name_lr_mult is not None:
492
+ logger.info(f"Learning rate x [{lr_mult}] parameter names: {names_lr_mult}\n")
493
+
494
+ configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil(
495
+ len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps) # only count optimizer update steps
496
+ configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes # for lr scheduler setting
497
+ if "num_warmup_steps" in configs["lr_scheduler"]:
498
+ configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes # for lr scheduler setting
499
+ lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"])
500
+ configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes # reset for multi-gpu
501
+ if "num_warmup_steps" in configs["lr_scheduler"]:
502
+ configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes # reset for multi-gpu
503
+
504
+ # Prepare everything with `accelerator`
505
+ transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader = accelerator.prepare(
506
+ transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader
507
+ )
508
+ # Set classes explicitly for everything
509
+ transformer: DistributedDataParallel
510
+ optimizer: AcceleratedOptimizer
511
+ lr_scheduler: AcceleratedScheduler
512
+ train_loader: DataLoaderShard
513
+ val_loader: DataLoaderShard
514
+ random_val_loader: DataLoaderShard
515
+
516
+ if args.use_ema:
517
+ ema_transformer.to(accelerator.device)
518
+
519
+ # For mixed precision training we cast all non-trainable weights to half-precision
520
+ # as these weights are only used for inference, keeping weights in full precision is not required.
521
+ weight_dtype = torch.float32
522
+ if accelerator.mixed_precision == "fp16":
523
+ weight_dtype = torch.float16
524
+ elif accelerator.mixed_precision == "bf16":
525
+ weight_dtype = torch.bfloat16
526
+
527
+ # Move `vae` and `image_encoder_dinov2` to gpu and cast to `weight_dtype`
528
+ vae.to(accelerator.device, dtype=weight_dtype)
529
+ image_encoder_dinov2.to(accelerator.device, dtype=weight_dtype)
530
+
531
+ # Training configs after distribution and accumulation setup
532
+ updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
533
+ total_updated_steps = configs["lr_scheduler"]["total_steps"]
534
+ if args.max_train_steps is None:
535
+ args.max_train_steps = total_updated_steps
536
+ assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps
537
+ if accelerator.num_processes > 1 and accelerator.is_main_process:
538
+ print()
539
+ accelerator.wait_for_everyone()
540
+ logger.info(f"Total batch size: [{total_batch_size}]")
541
+ logger.info(f"Learning rate: [{configs['optimizer']['lr']}]")
542
+ logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]")
543
+ logger.info(f"Total epochs: [{configs['train']['epochs']}]")
544
+ logger.info(f"Total steps: [{total_updated_steps}]")
545
+ logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]")
546
+ logger.info(f"Steps for validation: [{len(val_loader)}]\n")
547
+
548
+ # (Optional) Load checkpoint
549
+ global_update_step = 0
550
+ if args.resume_from_iter is not None:
551
+ if args.resume_from_iter < 0:
552
+ args.resume_from_iter = int(sorted(os.listdir(ckpt_dir))[-1])
553
+ logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n")
554
+ # Load everything
555
+ if version.parse(torch.__version__) >= version.parse("2.4.0"):
556
+ torch.serialization.add_safe_globals([
557
+ int, list, dict,
558
+ defaultdict,
559
+ Any,
560
+ DictConfig, ListConfig, Metadata, ContainerMetadata, AnyNode
561
+ ]) # avoid deserialization error when loading optimizer state
562
+ accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}")) # torch < 2.4.0 here for `weights_only=False`
563
+ global_update_step = int(args.resume_from_iter)
564
+
565
+ # Save all experimental parameters and model architecture of this run to a file (args and configs)
566
+ if accelerator.is_main_process:
567
+ exp_params = save_experiment_params(args, configs, exp_dir)
568
+ save_model_architecture(accelerator.unwrap_model(transformer), exp_dir)
569
+
570
+ # WandB logger
571
+ if accelerator.is_main_process:
572
+ if args.offline_wandb:
573
+ os.environ["WANDB_MODE"] = "offline"
574
+ wandb.init(
575
+ project=PROJECT_NAME, name=args.tag,
576
+ config=exp_params, dir=exp_dir,
577
+ resume=True
578
+ )
579
+ # Wandb artifact for logging experiment information
580
+ arti_exp_info = wandb.Artifact(args.tag, type="exp_info")
581
+ arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml"))
582
+ arti_exp_info.add_file(os.path.join(exp_dir, "model.txt"))
583
+ arti_exp_info.add_file(os.path.join(exp_dir, "log.txt")) # only save the log before training
584
+ wandb.log_artifact(arti_exp_info)
585
+
586
+ def get_sigmas(timesteps: Tensor, n_dim: int, dtype=torch.float32):
587
+ sigmas = noise_scheduler.sigmas.to(dtype=dtype, device=accelerator.device)
588
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
589
+ timesteps = timesteps.to(accelerator.device)
590
+
591
+ step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps]
592
+
593
+ sigma = sigmas[step_indices].flatten()
594
+ while len(sigma.shape) < n_dim:
595
+ sigma = sigma.unsqueeze(-1)
596
+ return sigma
597
+
598
+ # Start training
599
+ if accelerator.is_main_process:
600
+ print()
601
+ logger.info(f"Start training into {exp_dir}\n")
602
+ logger.logger.propagate = False # do not propagate to the root logger (console)
603
+ progress_bar = tqdm(
604
+ range(total_updated_steps),
605
+ initial=global_update_step,
606
+ desc="Training",
607
+ ncols=125,
608
+ disable=not accelerator.is_main_process
609
+ )
610
+ for batch in yield_forever(train_loader):
611
+
612
+ if global_update_step == args.max_train_steps:
613
+ progress_bar.close()
614
+ logger.logger.propagate = True # propagate to the root logger (console)
615
+ if accelerator.is_main_process:
616
+ wandb.finish()
617
+ logger.info("Training finished!\n")
618
+ return
619
+
620
+ transformer.train()
621
+
622
+ with accelerator.accumulate(transformer):
623
+
624
+ images = batch["images"] # [N, H, W, 3]
625
+ with torch.no_grad():
626
+ images = feature_extractor_dinov2(images=images, return_tensors="pt").pixel_values
627
+ images = images.to(device=accelerator.device, dtype=weight_dtype)
628
+ with torch.no_grad():
629
+ image_embeds = image_encoder_dinov2(images).last_hidden_state
630
+ negative_image_embeds = torch.zeros_like(image_embeds)
631
+
632
+ part_surfaces = batch["part_surfaces"] # [N, P, 6]
633
+ part_surfaces = part_surfaces.to(device=accelerator.device, dtype=weight_dtype)
634
+
635
+ num_parts = batch["num_parts"] # [M, ] The shape of num_parts is not fixed
636
+ num_objects = num_parts.shape[0] # M
637
+
638
+ with torch.no_grad():
639
+ latents = vae.encode(
640
+ part_surfaces,
641
+ **configs["model"]["vae"]
642
+ ).latent_dist.sample()
643
+
644
+ noise = torch.randn_like(latents)
645
+ # For weighting schemes where we sample timesteps non-uniformly
646
+ u = compute_density_for_timestep_sampling(
647
+ weighting_scheme=configs["train"]["weighting_scheme"],
648
+ batch_size=num_objects,
649
+ logit_mean=configs["train"]["logit_mean"],
650
+ logit_std=configs["train"]["logit_std"],
651
+ mode_scale=configs["train"]["mode_scale"],
652
+ )
653
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
654
+ timesteps = noise_scheduler.timesteps[indices].to(accelerator.device) # [M, ]
655
+ # Repeat the timesteps for each part
656
+ timesteps = timesteps.repeat_interleave(num_parts) # [N, ]
657
+
658
+ sigmas = get_sigmas(timesteps, len(latents.shape), weight_dtype)
659
+ latent_model_input = noisy_latents = (1. - sigmas) * latents + sigmas * noise
660
+
661
+ if configs["train"]["cfg_dropout_prob"] > 0:
662
+ # We use the same dropout mask for the same part
663
+ dropout_mask = torch.rand(num_objects, device=accelerator.device) < configs["train"]["cfg_dropout_prob"] # [M, ]
664
+ dropout_mask = dropout_mask.repeat_interleave(num_parts) # [N, ]
665
+ if dropout_mask.any():
666
+ image_embeds[dropout_mask] = negative_image_embeds[dropout_mask]
667
+
668
+ model_pred = transformer(
669
+ hidden_states=latent_model_input,
670
+ timestep=timesteps,
671
+ encoder_hidden_states=image_embeds,
672
+ attention_kwargs={"num_parts": num_parts}
673
+ ).sample
674
+
675
+ if configs["train"]["training_objective"] == "x0": # Section 5 of https://arxiv.org/abs/2206.00364
676
+ model_pred = model_pred * (-sigmas) + noisy_latents # predicted x_0
677
+ target = latents
678
+ elif configs["train"]["training_objective"] == 'v': # flow matching
679
+ target = noise - latents
680
+ elif configs["train"]["training_objective"] == '-v': # reverse flow matching
681
+ # The training objective for TripoSG is the reverse of the flow matching objective.
682
+ # It uses "different directions", i.e., the negative velocity.
683
+ # This is likely an engineering quirk rather than a deliberate design choice, and it is harmless in practice.
684
+ # In TripoSG's rectified flow scheduler, prev_sample = sample + (sigma - sigma_next) * model_output
685
+ # See TripoSG's scheduler https://github.com/VAST-AI-Research/TripoSG/blob/main/triposg/schedulers/scheduling_rectified_flow.py#L296
686
+ # While in diffusers's flow matching scheduler, prev_sample = sample + (sigma_next - sigma) * model_output
687
+ # See https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L454
688
+ target = latents - noise
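+ # A worked check of the sign convention (x and eps below are illustrative symbols, not variables):
+ # with x_t = (1 - sigma) * x + sigma * eps and target v = x - eps, one step of TripoSG's
+ # update x_{t'} = x_t + (sigma - sigma') * v gives (1 - sigma') * x + sigma' * eps,
+ # i.e. the sample moves toward the clean latents as sigma' decreases to 0.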
689
+ else:
690
+ raise ValueError(f"Unknown training objective [{configs['train']['training_objective']}]")
691
+
692
+ # For these weighting schemes use a uniform timestep sampling, so post-weight the loss
693
+ weighting = compute_loss_weighting_for_sd3(
694
+ configs["train"]["weighting_scheme"],
695
+ sigmas
696
+ )
697
+
698
+ loss = weighting * tF.mse_loss(model_pred.float(), target.float(), reduction="none")
699
+ loss = loss.mean(dim=list(range(1, len(loss.shape))))
700
+
701
+ # Backpropagate
702
+ accelerator.backward(loss.mean())
703
+ if accelerator.sync_gradients:
704
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
705
+
706
+ optimizer.step()
707
+ lr_scheduler.step()
708
+ optimizer.zero_grad()
709
+
710
+ # Checks if the accelerator has performed an optimization step behind the scenes
711
+ if accelerator.sync_gradients:
712
+ # Gather the losses across all processes for logging (if we use distributed training)
713
+ loss = accelerator.gather(loss.detach()).mean()
714
+
715
+ logs = {
716
+ "loss": loss.item(),
717
+ "lr": lr_scheduler.get_last_lr()[0]
718
+ }
719
+ if args.use_ema:
720
+ ema_transformer.step(transformer.parameters())
721
+ logs.update({"ema": ema_transformer.cur_decay_value})
722
+
723
+ progress_bar.set_postfix(**logs)
724
+ progress_bar.update(1)
725
+ global_update_step += 1
726
+
727
+ logger.info(
728
+ f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
729
+ f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
730
+ f", ema: {logs['ema']:.4f}" if args.use_ema else ""
731
+ )
732
+
733
+ # Log the training progress
734
+ if (
735
+ global_update_step % configs["train"]["log_freq"] == 0
736
+ or global_update_step == 1
737
+ or global_update_step % updated_steps_per_epoch == 0 # last step of an epoch
738
+ ):
739
+ if accelerator.is_main_process:
740
+ wandb.log({
741
+ "training/loss": logs["loss"],
742
+ "training/lr": logs["lr"],
743
+ }, step=global_update_step)
744
+ if args.use_ema:
745
+ wandb.log({
746
+ "training/ema": logs["ema"]
747
+ }, step=global_update_step)
748
+
749
+ # Save checkpoint
750
+ if (
751
+ global_update_step % configs["train"]["save_freq"] == 0 # 1. every `save_freq` steps
752
+ or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0 # 2. every `save_freq_epoch` epochs
753
+ or global_update_step == total_updated_steps # 3. the last training step
754
+ # or global_update_step == 1 # 4. first step
755
+ ):
756
+
757
+ gc.collect()
758
+ if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED:
759
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues
760
+ accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
761
+ elif accelerator.is_main_process:
762
+ accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
763
+ accelerator.wait_for_everyone() # ensure all processes have finished saving
764
+ gc.collect()
765
+
766
+ # Evaluate on the validation set
767
+ if args.max_val_steps > 0 and (
768
+ (global_update_step % configs["train"]["early_eval_freq"] == 0 and global_update_step < configs["train"]["early_eval"]) # 1. more frequently at the beginning
769
+ or global_update_step % configs["train"]["eval_freq"] == 0 # 2. every `eval_freq` steps
770
+ or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0 # 3. every `eval_freq_epoch` epochs
771
+ or global_update_step == total_updated_steps # 4. the last training step
772
+ or global_update_step == 1 # 5. first step
773
+ ):
774
+
775
+ # Use EMA parameters for evaluation
776
+ if args.use_ema:
777
+ # Store the Transformer parameters temporarily and load the EMA parameters to perform inference
778
+ ema_transformer.store(transformer.parameters())
779
+ ema_transformer.copy_to(transformer.parameters())
780
+
781
+ transformer.eval()
782
+
783
+ log_validation(
784
+ val_loader, random_val_loader,
785
+ feature_extractor_dinov2, image_encoder_dinov2,
786
+ vae, transformer,
787
+ global_update_step, eval_dir,
788
+ accelerator, logger,
789
+ args, configs
790
+ )
791
+
792
+ if args.use_ema:
793
+ # Switch back to the original Transformer parameters
794
+ ema_transformer.restore(transformer.parameters())
795
+
796
+ torch.cuda.empty_cache()
797
+ gc.collect()
798
+
799
+ @torch.no_grad()
800
+ def log_validation(
801
+ dataloader, random_dataloader,
802
+ feature_extractor_dinov2, image_encoder_dinov2,
803
+ vae, transformer,
804
+ global_step, eval_dir,
805
+ accelerator, logger,
806
+ args, configs
807
+ ):
808
+
809
+ val_noise_scheduler = RectifiedFlowScheduler.from_pretrained(
810
+ configs["model"]["pretrained_model_name_or_path"],
811
+ subfolder="scheduler"
812
+ )
813
+
814
+ pipeline = PartCrafterPipeline(
815
+ vae=vae,
816
+ transformer=accelerator.unwrap_model(transformer),
817
+ scheduler=val_noise_scheduler,
818
+ feature_extractor_dinov2=feature_extractor_dinov2,
819
+ image_encoder_dinov2=image_encoder_dinov2,
820
+ )
821
+
822
+ pipeline.set_progress_bar_config(disable=True)
823
+ # pipeline.enable_xformers_memory_efficient_attention()
824
+
825
+ if args.seed >= 0:
826
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
827
+ else:
828
+ generator = None
829
+
830
+
831
+ val_progress_bar = tqdm(
832
+ range(len(dataloader)) if args.max_val_steps is None else range(args.max_val_steps),
833
+ desc=f"Validation [{global_step:06d}]",
834
+ ncols=125,
835
+ disable=not accelerator.is_main_process
836
+ )
837
+
838
+ medias_dictlist, metrics_dictlist = defaultdict(list), defaultdict(list)
839
+
840
+ val_dataloader, random_val_dataloader = yield_forever(dataloader), yield_forever(random_dataloader)
841
+ val_step = 0
842
+ while val_step < args.max_val_steps:
843
+
844
+ if val_step < args.max_val_steps // 2:
845
+ # fix the first half
846
+ batch = next(val_dataloader)
847
+ else:
848
+ # randomly sample the next batch
849
+ batch = next(random_val_dataloader)
850
+
851
+ images = batch["images"]
852
+ if len(images.shape) == 5:
853
+ images = images[0] # (1, N, H, W, 3) -> (N, H, W, 3)
854
+ images = [Image.fromarray(image) for image in images.cpu().numpy()]
855
+ part_surfaces = batch["part_surfaces"].cpu().numpy()
856
+ if len(part_surfaces.shape) == 4:
857
+ part_surfaces = part_surfaces[0] # (1, N, P, 6) -> (N, P, 6)
858
+
859
+ N = len(images)
860
+
861
+ val_progress_bar.set_postfix(
862
+ {"num_parts": N}
863
+ )
864
+
865
+ with torch.autocast("cuda", torch.float16):
866
+ for guidance_scale in sorted(args.val_guidance_scales):
867
+ pred_part_meshes = pipeline(
868
+ images,
869
+ num_inference_steps=configs['val']['num_inference_steps'],
870
+ num_tokens=configs['model']['vae']['num_tokens'],
871
+ guidance_scale=guidance_scale,
872
+ attention_kwargs={"num_parts": N},
873
+ generator=generator,
874
+ max_num_expanded_coords=configs['val']['max_num_expanded_coords'],
875
+ use_flash_decoder=configs['val']['use_flash_decoder'],
876
+ ).meshes
877
+
878
+ # Save the generated meshes
879
+ if accelerator.is_main_process:
880
+ local_eval_dir = os.path.join(eval_dir, f"{global_step:06d}", f"guidance_scale_{guidance_scale:.1f}")
881
+ os.makedirs(local_eval_dir, exist_ok=True)
882
+ rendered_images_list, rendered_normals_list = [], []
883
+ # 1. save the gt image
884
+ images[0].save(os.path.join(local_eval_dir, f"{val_step:04d}.png"))
885
+ # 2. save the generated part meshes
886
+ for n in range(N):
887
+ if pred_part_meshes[n] is None:
888
+ # If the generated mesh is None (decoding error), use a dummy mesh
889
+ pred_part_meshes[n] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
890
+ pred_part_meshes[n].export(os.path.join(local_eval_dir, f"{val_step:04d}_{n:02d}.glb"))
891
+ # 3. render the generated mesh and save the rendered images
892
+ pred_mesh = get_colored_mesh_composition(pred_part_meshes)
893
+ rendered_images: List[Image.Image] = render_views_around_mesh(
894
+ pred_mesh,
895
+ num_views=configs['val']['rendering']['num_views'],
896
+ radius=configs['val']['rendering']['radius'],
897
+ )
898
+ rendered_normals: List[Image.Image] = render_normal_views_around_mesh(
899
+ pred_mesh,
900
+ num_views=configs['val']['rendering']['num_views'],
901
+ radius=configs['val']['rendering']['radius'],
902
+ )
903
+ export_renderings(
904
+ rendered_images,
905
+ os.path.join(local_eval_dir, f"{val_step:04d}.gif"),
906
+ fps=configs['val']['rendering']['fps']
907
+ )
908
+ export_renderings(
909
+ rendered_normals,
910
+ os.path.join(local_eval_dir, f"{val_step:04d}_normals.gif"),
911
+ fps=configs['val']['rendering']['fps']
912
+ )
913
+ rendered_images_list.append(rendered_images)
914
+ rendered_normals_list.append(rendered_normals)
915
+
916
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/gt_image"] += [images[0]] # List[Image.Image] TODO: support batch size > 1
917
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_images"] += rendered_images_list # List[List[Image.Image]]
918
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_normals"] += rendered_normals_list # List[List[Image.Image]]
919
+
920
+ ################################ Compute generation metrics ################################
921
+
922
+ parts_chamfer_distances, parts_f_scores = [], []
923
+
924
+ for n in range(N):
925
+ # gt_part_surface = part_surfaces[n]
926
+ # pred_part_mesh = pred_part_meshes[n]
927
+ # if pred_part_mesh is None:
928
+ # # If the generated mesh is None (decoding error), use a dummy mesh
929
+ # pred_part_mesh = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
930
+ # part_cd, part_f = compute_cd_and_f_score_in_training(
931
+ # gt_part_surface, pred_part_mesh,
932
+ # num_samples=configs['val']['metric']['cd_num_samples'],
933
+ # threshold=configs['val']['metric']['f1_score_threshold'],
934
+ # metric=configs['val']['metric']['cd_metric']
935
+ # )
936
+ # # avoid nan
937
+ # part_cd = configs['val']['metric']['default_cd'] if np.isnan(part_cd) else part_cd
938
+ # part_f = configs['val']['metric']['default_f1'] if np.isnan(part_f) else part_f
939
+ # parts_chamfer_distances.append(part_cd)
940
+ # parts_f_scores.append(part_f)
941
+
942
+ # TODO: Fix this
943
+ # Disable chamfer distance and F1 score for now
944
+ parts_chamfer_distances.append(0.0)
945
+ parts_f_scores.append(0.0)
946
+
947
+ parts_chamfer_distances = torch.tensor(parts_chamfer_distances, device=accelerator.device)
948
+ parts_f_scores = torch.tensor(parts_f_scores, device=accelerator.device)
949
+
950
+ metrics_dictlist[f"parts_chamfer_distance_cfg{guidance_scale:.1f}"].append(parts_chamfer_distances.mean())
951
+ metrics_dictlist[f"parts_f_score_cfg{guidance_scale:.1f}"].append(parts_f_scores.mean())
952
+
953
+ # Only log the last (biggest) cfg metrics in the progress bar
954
+ val_logs = {
955
+ "parts_chamfer_distance": parts_chamfer_distances.mean().item(),
956
+ "parts_f_score": parts_f_scores.mean().item(),
957
+ }
958
+ val_progress_bar.set_postfix(**val_logs)
959
+ logger.info(
960
+ f"Validation [{val_step:02d}/{args.max_val_steps:02d}] " +
961
+ f"parts_chamfer_distance: {val_logs['parts_chamfer_distance']:.4f}, parts_f_score: {val_logs['parts_f_score']:.4f}"
962
+ )
963
+ logger.info(
964
+ f"parts_chamfer_distances: {[f'{x:.4f}' for x in parts_chamfer_distances.tolist()]}"
965
+ )
966
+ logger.info(
967
+ f"parts_f_scores: {[f'{x:.4f}' for x in parts_f_scores.tolist()]}"
968
+ )
969
+ val_step += 1
970
+ val_progress_bar.update(1)
971
+
972
+ val_progress_bar.close()
973
+
974
+ if accelerator.is_main_process:
975
+ for key, value in medias_dictlist.items():
976
+ if isinstance(value[0], Image.Image): # assuming gt_image
977
+ image_grid = make_grid_for_images_or_videos(
978
+ value,
979
+ nrow=configs['val']['nrow'],
980
+ return_type='pil',
981
+ )
982
+ image_grid.save(os.path.join(eval_dir, f"{global_step:06d}", f"{key}.png"))
983
+ wandb.log({f"validation/{key}": wandb.Image(image_grid)}, step=global_step)
984
+ else: # assuming pred_rendered_images or pred_rendered_normals
985
+ image_grids = make_grid_for_images_or_videos(
986
+ value,
987
+ nrow=configs['val']['nrow'],
988
+ return_type='ndarray',
989
+ )
990
+ wandb.log({
991
+ f"validation/{key}": wandb.Video(
992
+ image_grids,
993
+ fps=configs['val']['rendering']['fps'],
994
+ format="gif"
995
+ )}, step=global_step)
996
+ image_grids = [Image.fromarray(image_grid.transpose(1, 2, 0)) for image_grid in image_grids]
997
+ export_renderings(
998
+ image_grids,
999
+ os.path.join(eval_dir, f"{global_step:06d}", f"{key}.gif"),
1000
+ fps=configs['val']['rendering']['fps']
1001
+ )
1002
+
1003
+ for k, v in metrics_dictlist.items():
1004
+ wandb.log({f"validation/{k}": torch.tensor(v).mean().item()}, step=global_step)
1005
+
1006
+ if __name__ == "__main__":
1007
+ main()
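To make the per-part timestep handling in the training loop above concrete, here is a minimal sketch of how one noise level per object is broadcast to all of its parts before noising. Shapes are toy values, and sigmas are drawn directly instead of being looked up through the scheduler, so this illustrates only the broadcasting.

import torch

# Hypothetical batch: M = 2 objects with 3 and 2 parts -> N = 5 part latents.
num_parts = torch.tensor([3, 2])                   # [M,]
latents = torch.randn(int(num_parts.sum()), 8, 4)  # [N, tokens, channels], toy sizes
noise = torch.randn_like(latents)

# One sigma per object, repeated for each of its parts (mirroring `timesteps.repeat_interleave(num_parts)`).
sigmas_per_object = torch.rand(num_parts.shape[0])       # [M,]
sigmas = sigmas_per_object.repeat_interleave(num_parts)  # [N,]
sigmas = sigmas.view(-1, 1, 1)                           # broadcast over the latent dims

# Same interpolation as in the training loop: x_t = (1 - sigma) * x_0 + sigma * noise.
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
print(noisy_latents.shape)  # torch.Size([5, 8, 4])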
src/utils/data_utils.py ADDED
@@ -0,0 +1,191 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import os
4
+ import numpy as np
5
+ import trimesh
6
+ import torch
7
+
8
+ def normalize_mesh(
9
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
10
+ scale: float = 2.0,
11
+ ):
12
+ # if not isinstance(mesh, trimesh.Trimesh) and not isinstance(mesh, trimesh.Scene):
13
+ # raise ValueError("Input mesh is not a trimesh.Trimesh or trimesh.Scene object.")
14
+ bbox = mesh.bounding_box
15
+ translation = -bbox.centroid
16
+ scale = scale / bbox.primitive.extents.max()
17
+ mesh.apply_translation(translation)
18
+ mesh.apply_scale(scale)
19
+ return mesh
20
+
21
+ def remove_overlapping_vertices(mesh: trimesh.Trimesh, reserve_material: bool = False):
22
+ if not isinstance(mesh, trimesh.Trimesh):
23
+ raise ValueError("Input mesh is not a trimesh.Trimesh object.")
24
+ vertices = mesh.vertices
25
+ faces = mesh.faces
26
+ unique_vertices, index_map, inverse_map = np.unique(
27
+ vertices, axis=0, return_index=True, return_inverse=True
28
+ )
29
+ clean_faces = inverse_map[faces]
30
+ clean_mesh = trimesh.Trimesh(vertices=unique_vertices, faces=clean_faces, process=True)
31
+ if reserve_material:
32
+ uv = mesh.visual.uv
33
+ material = mesh.visual.material
34
+ clean_uv = uv[index_map]
35
+ clean_visual = trimesh.visual.TextureVisuals(uv=clean_uv, material=material)
36
+ clean_mesh.visual = clean_visual
37
+ return clean_mesh
38
+
39
+ RGB = [
40
+ (82, 170, 220),
41
+ (215, 91, 78),
42
+ (45, 136, 117),
43
+ (247, 172, 83),
44
+ (124, 121, 121),
45
+ (127, 171, 209),
46
+ (243, 152, 101),
47
+ (145, 204, 192),
48
+ (150, 59, 121),
49
+ (181, 206, 78),
50
+ (189, 119, 149),
51
+ (199, 193, 222),
52
+ (200, 151, 54),
53
+ (236, 110, 102),
54
+ (238, 182, 212),
55
+ ]
56
+
57
+
58
+ def get_colored_mesh_composition(
59
+ meshes: Union[List[trimesh.Trimesh], trimesh.Scene],
60
+ is_random: bool = True,
61
+ is_sorted: bool = False,
62
+ RGB: List[Tuple] = RGB
63
+ ):
64
+ if isinstance(meshes, trimesh.Scene):
65
+ meshes = meshes.dump()
66
+ if is_sorted:
67
+ volumes = []
68
+ for mesh in meshes:
69
+ try:
70
+ volume = mesh.volume
71
+ except Exception:
72
+ volume = 0.0
73
+ volumes.append(volume)
74
+ # sort by volume from large to small
75
+ meshes = [x for _, x in sorted(zip(volumes, meshes), key=lambda pair: pair[0], reverse=True)]
76
+ colored_scene = trimesh.Scene()
77
+ for idx, mesh in enumerate(meshes):
78
+ if is_random:
79
+ color = (np.random.rand(3) * 256).astype(int)
80
+ else:
81
+ color = np.array(RGB[idx % len(RGB)])
82
+ mesh.visual = trimesh.visual.ColorVisuals(
83
+ mesh=mesh,
84
+ vertex_colors=color,
85
+ )
86
+ colored_scene.add_geometry(mesh)
87
+ return colored_scene
88
+
89
+ def mesh_to_surface(
90
+ mesh: trimesh.Trimesh,
91
+ num_pc: int = 204800,
92
+ clip_to_num_vertices: bool = False,
93
+ return_dict: bool = False,
94
+ ):
95
+ # if not isinstance(mesh, trimesh.Trimesh):
96
+ # raise ValueError("mesh must be a trimesh.Trimesh object")
97
+ if clip_to_num_vertices:
98
+ num_pc = min(num_pc, mesh.vertices.shape[0])
99
+ points, face_indices = mesh.sample(num_pc, return_index=True)
100
+ normals = mesh.face_normals[face_indices]
101
+ if return_dict:
102
+ return {
103
+ "surface_points": points,
104
+ "surface_normals": normals,
105
+ }
106
+ return points, normals
107
+
108
+ def scene_to_parts(
109
+ mesh: trimesh.Scene,
110
+ normalize: bool = True,
111
+ scale: float = 2.0,
112
+ num_part_pc: int = 204800,
113
+ clip_to_num_part_vertices: bool = False,
114
+ return_type: Literal["mesh", "point"] = "mesh",
115
+ ) -> Union[List[trimesh.Geometry], List[Dict[str, np.ndarray]]]:
116
+ if not isinstance(mesh, trimesh.Scene):
117
+ raise ValueError("mesh must be a trimesh.Scene object")
118
+ if normalize:
119
+ mesh = normalize_mesh(mesh, scale=scale)
120
+ parts: List[trimesh.Geometry] = mesh.dump()
121
+ if return_type == "point":
122
+ datas: List[Dict[str, np.ndarray]] = []
123
+ for geom in parts:
124
+ data = mesh_to_surface(
125
+ geom,
126
+ num_pc=num_part_pc,
127
+ clip_to_num_vertices=clip_to_num_part_vertices,
128
+ return_dict=True,
129
+ )
130
+ datas.append(data)
131
+ return datas
132
+ elif return_type == "mesh":
133
+ return parts
134
+ else:
135
+ raise ValueError("return_type must be 'mesh' or 'point'")
136
+
137
+ def get_center(mesh: trimesh.Trimesh, method: Literal['mass', 'bbox']):
138
+ if method == 'mass':
139
+ return mesh.center_mass
140
+ elif method == 'bbox':
141
+ return mesh.bounding_box.centroid
142
+ else:
143
+ raise ValueError("method must be 'mass' or 'bbox'")
144
+
145
+ def get_direction(vector: np.ndarray):
146
+ return vector / np.linalg.norm(vector)
147
+
148
+ def move_mesh_by_center(mesh: trimesh.Trimesh, scale: float, method: Literal['mass', 'bbox'] = 'mass'):
149
+ offset = scale - 1
150
+ center = get_center(mesh, method)
151
+ direction = get_direction(center)
152
+ translation = direction * offset
153
+ mesh = mesh.copy()
154
+ mesh.apply_translation(translation)
155
+ return mesh
156
+
157
+ def move_meshes_by_center(meshes: Union[List[trimesh.Trimesh], trimesh.Scene], scale: float):
158
+ if isinstance(meshes, trimesh.Scene):
159
+ meshes = meshes.dump()
160
+ moved_meshes = []
161
+ for mesh in meshes:
162
+ moved_mesh = move_mesh_by_center(mesh, scale)
163
+ moved_meshes.append(moved_mesh)
164
+ moved_meshes = trimesh.Scene(moved_meshes)
165
+ return moved_meshes
166
+
167
+ def get_series_splited_meshes(meshes: List[trimesh.Trimesh], scale: float, num_steps: int) -> List[trimesh.Scene]:
168
+ series_meshes = []
169
+ for i in range(num_steps):
170
+ temp_scale = 1 + (scale - 1) * i / (num_steps - 1)
171
+ temp_meshes = move_meshes_by_center(meshes, temp_scale)
172
+ series_meshes.append(temp_meshes)
173
+ return series_meshes
174
+
175
+ def load_surface(data, num_pc=204800):
176
+
177
+ surface = data["surface_points"] # Nx3
178
+ normal = data["surface_normals"] # Nx3
179
+
180
+ rng = np.random.default_rng()
181
+ ind = rng.choice(surface.shape[0], num_pc, replace=False)
182
+ surface = torch.FloatTensor(surface[ind])
183
+ normal = torch.FloatTensor(normal[ind])
184
+ surface = torch.cat([surface, normal], dim=-1)
185
+
186
+ return surface
187
+
188
+ def load_surfaces(surfaces, num_pc=204800):
189
+ surfaces = [load_surface(surface, num_pc) for surface in surfaces]
190
+ surfaces = torch.stack(surfaces, dim=0)
191
+ return surfaces
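As a usage sketch for the helpers above: turning a multi-part scene into the stacked `[N, P, 6]` surface tensor of points and normals. The GLB path is hypothetical and the point counts are kept small for illustration.

import trimesh
from src.utils.data_utils import (
    scene_to_parts,
    load_surfaces,
    get_colored_mesh_composition,
)

scene = trimesh.load("example_parts.glb")  # hypothetical multi-part asset
assert isinstance(scene, trimesh.Scene)

# Normalize the scene and sample an oriented point cloud for each part.
part_clouds = scene_to_parts(scene, normalize=True, num_part_pc=4096, return_type="point")
part_surfaces = load_surfaces(part_clouds, num_pc=2048)  # [N, 2048, 6] = xyz + normals per part

# Recolor the parts for a quick visual check of the decomposition.
colored = get_colored_mesh_composition(scene_to_parts(scene, normalize=False, return_type="mesh"))
colored.export("example_parts_colored.glb")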
src/utils/image_utils.py ADDED
@@ -0,0 +1,151 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ from skimage.morphology import remove_small_objects
4
+ from skimage.measure import label
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ from torchvision import transforms
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+
13
+ def find_bounding_box(gray_image):
14
+ _, binary_image = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)
15
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
16
+ max_contour = max(contours, key=cv2.contourArea)
17
+ x, y, w, h = cv2.boundingRect(max_contour)
18
+ return x, y, w, h
19
+
20
+ def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1, device='cuda'):
21
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
22
+ if img is None:
23
+ return f"invalid image path {img_path}"
24
+
25
+ def is_valid_alpha(alpha, min_ratio = 0.01):
26
+ bins = 20
27
+ if isinstance(alpha, np.ndarray):
28
+ hist = cv2.calcHist([alpha], [0], None, [bins], [0, 256])
29
+ else:
30
+ hist = torch.histc(alpha, bins=bins, min=0, max=1)
31
+ min_hist_val = alpha.shape[0] * alpha.shape[1] * min_ratio
32
+ return hist[0] >= min_hist_val and hist[-1] >= min_hist_val
33
+
34
+ def rmbg(image: torch.Tensor) -> torch.Tensor:
35
+ image = TF.normalize(image, [0.5,0.5,0.5], [1.0,1.0,1.0]).unsqueeze(0)
36
+ result=rmbg_net(image)
37
+ return result[0][0]
38
+
39
+ if len(img.shape) == 2:
40
+ num_channels = 1
41
+ else:
42
+ num_channels = img.shape[2]
43
+
44
+ # check if too large
45
+ height, width = img.shape[:2]
46
+ if height > width:
47
+ scale = 2000 / height
48
+ else:
49
+ scale = 2000 / width
50
+ if scale < 1:
51
+ new_size = (int(width * scale), int(height * scale))
52
+ img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
53
+
54
+ if img.dtype != 'uint8':
55
+ img = (img * (255. / np.iinfo(img.dtype).max)).astype(np.uint8)
56
+
57
+ rgb_image = None
58
+ alpha = None
59
+
60
+ if num_channels == 1:
61
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
62
+ elif num_channels == 3:
63
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
64
+ elif num_channels == 4:
65
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
66
+
67
+ b, g, r, alpha = cv2.split(img)
68
+ if not is_valid_alpha(alpha):
69
+ alpha = None
70
+ else:
71
+ alpha_gpu = torch.from_numpy(alpha).unsqueeze(0).to(device).float() / 255.
72
+ else:
73
+ return f"invalid image: channels {num_channels}"
74
+
75
+ rgb_image_gpu = torch.from_numpy(rgb_image).to(device).float().permute(2, 0, 1) / 255.
76
+ if alpha is None:
77
+ resize_transform = transforms.Resize((384, 384), antialias=True)
78
+ rgb_image_resized = resize_transform(rgb_image_gpu)
79
+ normalize_image = rgb_image_resized * 2 - 1
80
+
81
+ mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
82
+ resize_transform = transforms.Resize((1024, 1024), antialias=True)
83
+ rgb_image_resized = resize_transform(rgb_image_gpu)
84
+ max_value = rgb_image_resized.flatten().max()
85
+ if max_value < 1e-3:
86
+ return "invalid image: pure black image"
87
+ normalize_image = rgb_image_resized / max_value - mean_color
88
+ normalize_image = normalize_image.unsqueeze(0)
89
+ resize_transform = transforms.Resize((rgb_image_gpu.shape[1], rgb_image_gpu.shape[2]), antialias=True)
90
+
91
+ # seg from rmbg
92
+ alpha_gpu_rmbg = rmbg(rgb_image_resized)
93
+ alpha_gpu_rmbg = alpha_gpu_rmbg.squeeze(0)
94
+ alpha_gpu_rmbg = resize_transform(alpha_gpu_rmbg)
95
+ ma, mi = alpha_gpu_rmbg.max(), alpha_gpu_rmbg.min()
96
+ alpha_gpu_rmbg = (alpha_gpu_rmbg - mi) / (ma - mi)
97
+
98
+ alpha_gpu = alpha_gpu_rmbg
99
+
100
+ alpha_gpu_tmp = alpha_gpu * 255
101
+ alpha = alpha_gpu_tmp.to(torch.uint8).squeeze().cpu().numpy()
102
+
103
+ _, alpha = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
104
+ labeled_alpha = label(alpha)
105
+ cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
106
+ cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
107
+ alpha = cleaned_alpha * 255
108
+ alpha_gpu = torch.from_numpy(cleaned_alpha).to(device).float().unsqueeze(0)
109
+ x, y, w, h = find_bounding_box(alpha)
110
+
111
+ # If alpha is provided, the bounds of all foreground are used
112
+ else:
113
+ rows, cols = np.where(alpha > 0)
114
+ if rows.size > 0 and cols.size > 0:
115
+ x_min = np.min(cols)
116
+ y_min = np.min(rows)
117
+ x_max = np.max(cols)
118
+ y_max = np.max(rows)
119
+
120
+ width = x_max - x_min + 1
121
+ height = y_max - y_min + 1
122
+ x, y, w, h = x_min, y_min, width, height
123
+
124
+ if np.all(alpha==0):
125
+ raise ValueError(f"input image too small")
126
+
127
+ bg_gray = bg_color[0]
128
+ bg_color = torch.from_numpy(bg_color).float().to(device).repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
129
+ rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
130
+ padding_size = [0] * 6
131
+ if w > h:
132
+ padding_size[0] = int(w * padding_ratio)
133
+ padding_size[2] = int(padding_size[0] + (w - h) / 2)
134
+ else:
135
+ padding_size[2] = int(h * padding_ratio)
136
+ padding_size[0] = int(padding_size[2] + (h - w) / 2)
137
+ padding_size[1] = padding_size[0]
138
+ padding_size[3] = padding_size[2]
139
+ padded_tensor = F.pad(rgb_image_gpu[:, y:(y+h), x:(x+w)], pad=tuple(padding_size), mode='constant', value=bg_gray)
140
+
141
+ return padded_tensor
142
+
143
+ def prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=None, padding_ratio=0.1, device='cuda'):
144
+ if os.path.isfile(image_path):
145
+ img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net, padding_ratio=padding_ratio, device=device)
146
+ img_np = img_tensor.permute(1,2,0).cpu().numpy()
147
+ img_pil = Image.fromarray((img_np*255).astype(np.uint8))
148
+
149
+ return img_pil
150
+ else:
151
+ raise ValueError(f"Invalid image path: {image_path}")
src/utils/inference_utils.py ADDED
@@ -0,0 +1,228 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import scipy.ndimage
7
+ from skimage import measure
8
+ from einops import repeat
9
+ import torch.nn.functional as F
10
+
11
+ def generate_dense_grid_points(
12
+ bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
13
+ ):
14
+ length = bbox_max - bbox_min
15
+ num_cells = np.exp2(octree_depth)
16
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
17
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
18
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
19
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
20
+ xyz = np.stack((xs, ys, zs), axis=-1)
21
+ xyz = xyz.reshape(-1, 3)
22
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
23
+
24
+ return xyz, grid_size, length
25
+
26
+ def generate_dense_grid_points_gpu(
27
+ bbox_min: torch.Tensor,
28
+ bbox_max: torch.Tensor,
29
+ octree_depth: int,
30
+ indexing: str = "ij",
31
+ dtype: torch.dtype = torch.float16
32
+ ):
33
+ length = bbox_max - bbox_min
34
+ num_cells = 2 ** octree_depth
35
+ device = bbox_min.device
36
+
37
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(num_cells), dtype=dtype, device=device)
38
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(num_cells), dtype=dtype, device=device)
39
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(num_cells), dtype=dtype, device=device)
40
+
41
+ xs, ys, zs = torch.meshgrid(x, y, z, indexing=indexing)
42
+ xyz = torch.stack((xs, ys, zs), dim=-1)
43
+ xyz = xyz.view(-1, 3)
44
+ grid_size = [int(num_cells), int(num_cells), int(num_cells)]
45
+
46
+ return xyz, grid_size, length
47
+
48
+ def find_mesh_grid_coordinates_fast_gpu(
49
+ occupancy_grid,
50
+ n_limits=-1
51
+ ):
52
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
53
+ occupied = core_grid > 0
54
+
55
+ neighbors_unoccupied = (
56
+ (occupancy_grid[:-2, :-2, :-2] < 0)
57
+ | (occupancy_grid[:-2, :-2, 1:-1] < 0)
58
+ | (occupancy_grid[:-2, :-2, 2:] < 0) # x-1, y-1, z-1/0/1
59
+ | (occupancy_grid[:-2, 1:-1, :-2] < 0)
60
+ | (occupancy_grid[:-2, 1:-1, 1:-1] < 0)
61
+ | (occupancy_grid[:-2, 1:-1, 2:] < 0) # x-1, y0, z-1/0/1
62
+ | (occupancy_grid[:-2, 2:, :-2] < 0)
63
+ | (occupancy_grid[:-2, 2:, 1:-1] < 0)
64
+ | (occupancy_grid[:-2, 2:, 2:] < 0) # x-1, y+1, z-1/0/1
65
+ | (occupancy_grid[1:-1, :-2, :-2] < 0)
66
+ | (occupancy_grid[1:-1, :-2, 1:-1] < 0)
67
+ | (occupancy_grid[1:-1, :-2, 2:] < 0) # x0, y-1, z-1/0/1
68
+ | (occupancy_grid[1:-1, 1:-1, :-2] < 0)
69
+ | (occupancy_grid[1:-1, 1:-1, 2:] < 0) # x0, y0, z-1/1
70
+ | (occupancy_grid[1:-1, 2:, :-2] < 0)
71
+ | (occupancy_grid[1:-1, 2:, 1:-1] < 0)
72
+ | (occupancy_grid[1:-1, 2:, 2:] < 0) # x0, y+1, z-1/0/1
73
+ | (occupancy_grid[2:, :-2, :-2] < 0)
74
+ | (occupancy_grid[2:, :-2, 1:-1] < 0)
75
+ | (occupancy_grid[2:, :-2, 2:] < 0) # x+1, y-1, z-1/0/1
76
+ | (occupancy_grid[2:, 1:-1, :-2] < 0)
77
+ | (occupancy_grid[2:, 1:-1, 1:-1] < 0)
78
+ | (occupancy_grid[2:, 1:-1, 2:] < 0) # x+1, y0, z-1/0/1
79
+ | (occupancy_grid[2:, 2:, :-2] < 0)
80
+ | (occupancy_grid[2:, 2:, 1:-1] < 0)
81
+ | (occupancy_grid[2:, 2:, 2:] < 0) # x+1, y+1, z-1/0/1
82
+ )
83
+ core_mesh_coords = torch.nonzero(occupied & neighbors_unoccupied, as_tuple=False) + 1
84
+
85
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
86
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
87
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
88
+ core_mesh_coords = core_mesh_coords[ind]
89
+
90
+ return core_mesh_coords
91
+
92
+ def find_candidates_band(
93
+ occupancy_grid: torch.Tensor,
94
+ band_threshold: float,
95
+ n_limits: int = -1
96
+ ) -> torch.Tensor:
97
+ """
98
+ Returns the coordinates of all voxels in the occupancy_grid where |value| < band_threshold.
99
+
100
+ Args:
101
+ occupancy_grid (torch.Tensor): A 3D tensor of SDF values.
102
+ band_threshold (float): The threshold below which |SDF| must be to include the voxel.
103
+ n_limits (int): Maximum number of points to return (-1 for no limit)
104
+
105
+ Returns:
106
+ torch.Tensor: A 2D tensor of coordinates (N x 3) where each row is [x, y, z].
107
+ """
108
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
109
+ # logits to sdf
110
+ core_grid = torch.sigmoid(core_grid) * 2 - 1
111
+ # Create a boolean mask for all cells in the band
112
+ in_band = torch.abs(core_grid) < band_threshold
113
+
114
+ # Get coordinates of all voxels in the band
115
+ core_mesh_coords = torch.nonzero(in_band, as_tuple=False) + 1
116
+
117
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
118
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
119
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
120
+ core_mesh_coords = core_mesh_coords[ind]
121
+
122
+ return core_mesh_coords
123
+
124
+ def expand_edge_region_fast(edge_coords, grid_size, dtype):
125
+ expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=dtype, requires_grad=False)
126
+ expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1
127
+ if grid_size < 512:
128
+ kernel_size = 5
129
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=2).squeeze()
130
+ else:
131
+ kernel_size = 3
132
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=1).squeeze()
133
+ expanded_coords_low_res = torch.nonzero(pooled_tensor, as_tuple=False).to(torch.int16)
134
+
135
+ expanded_coords_high_res = torch.stack([
136
+ torch.cat((expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1)),
137
+ torch.cat((expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2+1, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2 + 1)),
138
+ torch.cat((expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1))
139
+ ], dim=1)
140
+
141
+ return expanded_coords_high_res
142
+
143
+ def zoom_block(block, scale_factor, order=3):
144
+ block = block.astype(np.float32)
145
+ return scipy.ndimage.zoom(block, scale_factor, order=order)
146
+
147
+ def parallel_zoom(occupancy_grid, scale_factor):
148
+ result = torch.nn.functional.interpolate(occupancy_grid.unsqueeze(0).unsqueeze(0), scale_factor=scale_factor)
149
+ return result.squeeze(0).squeeze(0)
150
+
151
+
152
+ @torch.no_grad()
153
+ def hierarchical_extract_geometry(
154
+ geometric_func: Callable,
155
+ device: torch.device,
156
+ dtype: torch.dtype,
157
+ bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
158
+ dense_octree_depth: int = 8,
159
+ hierarchical_octree_depth: int = 9,
160
+ max_num_expanded_coords: int = int(1e8),
161
+ verbose: bool = False,
162
+ ):
163
+ """
164
+ Args:
165
+ geometric_func:
166
+ device:
167
+ bounds:
168
+ dense_octree_depth:
169
+ hierarchical_octree_depth:
170
+ Returns:
171
+ """
172
+ if isinstance(bounds, float):
173
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
174
+
175
+ bbox_min = torch.tensor(bounds[0:3]).to(device)
176
+ bbox_max = torch.tensor(bounds[3:6]).to(device)
177
+ bbox_size = bbox_max - bbox_min
178
+
179
+ xyz_samples, grid_size, length = generate_dense_grid_points_gpu(
180
+ bbox_min=bbox_min,
181
+ bbox_max=bbox_max,
182
+ octree_depth=dense_octree_depth,
183
+ indexing="ij",
184
+ dtype=dtype
185
+ )
186
+
187
+ if verbose:
188
+ print(f'step 1 query num: {xyz_samples.shape[0]}')
189
+ grid_logits = geometric_func(xyz_samples.unsqueeze(0)).to(dtype).view(grid_size[0], grid_size[1], grid_size[2])
190
+ # print(f'step 1 grid_logits shape: {grid_logits.shape}')
191
+ for i in range(hierarchical_octree_depth - dense_octree_depth):
192
+ curr_octree_depth = dense_octree_depth + i + 1
193
+ # upsample
194
+ grid_size = 2**curr_octree_depth
195
+ normalize_offset = grid_size / 2
196
+ high_res_occupancy = parallel_zoom(grid_logits, 2).to(dtype)
197
+
198
+ band_threshold = 1.0
199
+ edge_coords = find_candidates_band(grid_logits, band_threshold)
200
+ expanded_coords = expand_edge_region_fast(edge_coords, grid_size=int(grid_size/2), dtype=dtype).to(dtype)
201
+ if verbose:
202
+ print(f'step {i+2} query num: {len(expanded_coords)}')
203
+ if max_num_expanded_coords > 0 and len(expanded_coords) > max_num_expanded_coords:
204
+ raise ValueError(f"expanded_coords is too large, {len(expanded_coords)} > {max_num_expanded_coords}")
205
+ expanded_coords_norm = (expanded_coords - normalize_offset) * (abs(bounds[0]) / normalize_offset)
206
+
207
+ all_logits = None
208
+
209
+ all_logits = geometric_func(expanded_coords_norm.unsqueeze(0)).to(dtype)
210
+ all_logits = torch.cat([expanded_coords_norm, all_logits[0]], dim=1)
211
+ # print("all logits shape = ", all_logits.shape)
212
+
213
+ indices = all_logits[..., :3]
214
+ indices = indices * (normalize_offset / abs(bounds[0])) + normalize_offset
215
+ indices = indices.type(torch.IntTensor)
216
+ values = all_logits[:, 3]
217
+ # breakpoint()
218
+ high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
219
+ grid_logits = high_res_occupancy
220
+ # torch.cuda.empty_cache()
221
+
222
+ if verbose:
223
+ print("final grids shape = ", grid_logits.shape)
224
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits.float().cpu().numpy(), 0, method="lewiner")
225
+ vertices = vertices / (2**hierarchical_octree_depth) * bbox_size.cpu().numpy() + bbox_min.cpu().numpy()
226
+ mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
227
+
228
+ return mesh_v_f
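
`hierarchical_extract_geometry` first evaluates the field on a dense `2**dense_octree_depth` grid and then, at each extra octree level, re-queries only a band of cells near the current zero level-set before running marching cubes. A minimal sketch of how it can be driven follows; it is not part of this commit, it assumes a CUDA device, and the analytic sphere stands in for the learned latent decoder, with arbitrary illustrative depths.

```python
# Illustrative only: extract a mesh from a toy analytic field with the helper above.
import torch
import trimesh

def sphere_logits(points: torch.Tensor) -> torch.Tensor:
    # points: (1, N, 3). Positive inside the sphere, negative outside,
    # scaled so it behaves like the occupancy logits of a real decoder.
    return (0.8 - points.norm(dim=-1, keepdim=True)) * 20.0

vertices, faces = hierarchical_extract_geometry(
    geometric_func=sphere_logits,
    device=torch.device("cuda"),
    dtype=torch.float16,
    bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
    dense_octree_depth=7,           # 128^3 coarse grid
    hierarchical_octree_depth=8,    # one refinement pass near the surface
    verbose=True,
)
trimesh.Trimesh(vertices, faces).export("sphere.obj")
```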
src/utils/metric_utils.py ADDED
@@ -0,0 +1,159 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import trimesh
4
+ import numpy as np
5
+ from sklearn.neighbors import NearestNeighbors
6
+
7
+ def sample_from_mesh(
8
+ mesh: trimesh.Trimesh,
9
+ num_samples: Optional[int] = 10000,
10
+ ):
11
+ if num_samples is None:
12
+ return mesh.vertices
13
+ else:
14
+ return mesh.sample(num_samples)
15
+
16
+ def sample_two_meshes(
17
+ mesh1: trimesh.Trimesh,
18
+ mesh2: trimesh.Trimesh,
19
+ num_samples: Optional[int] = 10000,
20
+ ):
21
+ points1 = sample_from_mesh(mesh1, num_samples)
22
+ points2 = sample_from_mesh(mesh2, num_samples)
23
+ return points1, points2
24
+
25
+ def compute_nearest_distance(
26
+ points1: np.ndarray,
27
+ points2: np.ndarray,
28
+ metric: str = 'l2'
29
+ ) -> np.ndarray:
30
+ # Compute nearest neighbor distance from points1 to points2
31
+ nn = NearestNeighbors(n_neighbors=1, leaf_size=30, algorithm='kd_tree', metric=metric).fit(points2)
32
+ min_dist = nn.kneighbors(points1)[0]
33
+ return min_dist
34
+
35
+ def compute_mutual_nearest_distance(
36
+ points1: np.ndarray,
37
+ points2: np.ndarray,
38
+ metric: str = 'l2'
39
+ ) -> np.ndarray:
40
+ min_1_to_2 = compute_nearest_distance(points1, points2, metric=metric)
41
+ min_2_to_1 = compute_nearest_distance(points2, points1, metric=metric)
42
+ return min_1_to_2, min_2_to_1
43
+
44
+ def compute_mutual_nearest_distance_for_meshes(
45
+ mesh1: trimesh.Trimesh,
46
+ mesh2: trimesh.Trimesh,
47
+ num_samples: Optional[int] = 10000,
48
+ metric: str = 'l2'
49
+ ) -> Tuple[np.ndarray, np.ndarray]:
50
+ points1 = sample_from_mesh(mesh1, num_samples)
51
+ points2 = sample_from_mesh(mesh2, num_samples)
52
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance(points1, points2, metric=metric)
53
+ return min_1_to_2, min_2_to_1
54
+
55
+ def compute_chamfer_distance(
56
+ mesh1: trimesh.Trimesh,
57
+ mesh2: trimesh.Trimesh,
58
+ num_samples: int = 10000,
59
+ metric: str = 'l2'
60
+ ):
61
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
62
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
63
+ return chamfer_dist
64
+
65
+ def compute_f_score(
66
+ mesh1: trimesh.Trimesh,
67
+ mesh2: trimesh.Trimesh,
68
+ num_samples: int = 10000,
69
+ threshold: float = 0.1,
70
+ metric: str = 'l2'
71
+ ):
72
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
73
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
74
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
75
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
76
+ return fscore
77
+
78
+ def compute_cd_and_f_score(
79
+ mesh1: trimesh.Trimesh,
80
+ mesh2: trimesh.Trimesh,
81
+ num_samples: Optional[int] = 10000,
82
+ threshold: float = 0.1,
83
+ metric: str = 'l2'
84
+ ):
85
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
86
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
87
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
88
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
89
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
90
+ return chamfer_dist, fscore
91
+
92
+ def compute_cd_and_f_score_in_training(
93
+ gt_surface: np.ndarray,
94
+ pred_mesh: trimesh.Trimesh,
95
+ num_samples: int = 204800,
96
+ threshold: float = 0.1,
97
+ metric: str = 'l2'
98
+ ):
99
+ gt_points = gt_surface[:, :3]
100
+ num_samples = max(num_samples, gt_points.shape[0])
101
+ gt_points = gt_points[np.random.choice(gt_points.shape[0], num_samples, replace=False)]
102
+ pred_points = sample_from_mesh(pred_mesh, num_samples)
103
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance(gt_points, pred_points, metric=metric)
104
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
105
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
106
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
107
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
108
+ return chamfer_dist, fscore
109
+
110
+ def get_voxel_set(
111
+ mesh: trimesh.Trimesh,
112
+ num_grids: int = 64,
113
+ scale: float = 2.0,
114
+ ):
115
+ if not isinstance(mesh, trimesh.Trimesh):
116
+ raise ValueError("mesh must be a trimesh.Trimesh object")
117
+ pitch = scale / num_grids
118
+ voxel_grids: trimesh.voxel.base.VoxelGrid = mesh.voxelized(pitch=pitch).fill()
119
+ voxels = set(map(tuple, np.round(voxel_grids.points / pitch).astype(int)))
120
+ return voxels
121
+
122
+ def compute_IoU(
123
+ mesh1: trimesh.Trimesh,
124
+ mesh2: trimesh.Trimesh,
125
+ num_grids: int = 64,
126
+ scale: float = 2.0,
127
+ ):
128
+ if not isinstance(mesh1, trimesh.Trimesh) or not isinstance(mesh2, trimesh.Trimesh):
129
+ raise ValueError("mesh1 and mesh2 must be trimesh.Trimesh objects")
130
+ voxels1 = get_voxel_set(mesh1, num_grids, scale)
131
+ voxels2 = get_voxel_set(mesh2, num_grids, scale)
132
+ intersection = voxels1 & voxels2
133
+ union = voxels1 | voxels2
134
+ iou = len(intersection) / len(union) if len(union) > 0 else 0.0
135
+ return iou
136
+
137
+ def compute_IoU_for_scene(
138
+ scene: Union[trimesh.Scene, List[trimesh.Trimesh]],
139
+ num_grids: int = 64,
140
+ scale: float = 2.0,
141
+ return_type: Literal["iou", "iou_list"] = "iou",
142
+ ):
143
+ if isinstance(scene, trimesh.Scene):
144
+ scene = scene.dump()
145
+ if isinstance(scene, list) and len(scene) > 1 and isinstance(scene[0], trimesh.Trimesh):
146
+ meshes = scene
147
+ else:
148
+ raise ValueError("scene must be a trimesh.Scene object or a list of trimesh.Trimesh objects")
149
+ ious = []
150
+ for i in range(len(meshes)):
151
+ for j in range(i+1, len(meshes)):
152
+ iou = compute_IoU(meshes[i], meshes[j], num_grids, scale)
153
+ ious.append(iou)
154
+ if return_type == "iou":
155
+ return np.mean(ious)
156
+ elif return_type == "iou_list":
157
+ return ious
158
+ else:
159
+ raise ValueError("return_type must be 'iou' or 'iou_list'")
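
These metrics compare two meshes by Chamfer distance and F-score on sampled surface points, and by IoU on filled voxel grids. Below is a small self-contained sketch, not part of this commit, using trimesh primitives in place of generated and ground-truth parts; the threshold and grid size are simply the defaults used above.

```python
# Illustrative only: evaluate two primitive meshes with the metrics defined above.
import trimesh

mesh_a = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
mesh_b = trimesh.creation.icosphere(subdivisions=3, radius=0.6)

cd, f1 = compute_cd_and_f_score(mesh_a, mesh_b, num_samples=10000, threshold=0.1)
iou = compute_IoU(mesh_a, mesh_b, num_grids=64, scale=2.0)
print(f"Chamfer: {cd:.4f}  F-score@0.1: {f1:.3f}  IoU: {iou:.3f}")
```

`compute_IoU_for_scene` applies the same pairwise IoU to every pair of parts in a multi-part scene, which can be used to quantify overlap between generated parts.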
src/utils/render_utils.py ADDED
@@ -0,0 +1,411 @@
1
+ from src.utils.typing_utils import *
2
+
3
+ import os
4
+ import numpy as np
5
+ from PIL import Image
6
+ import trimesh
7
+ from trimesh.transformations import rotation_matrix
8
+ import pyrender
9
+ from diffusers.utils import export_to_video
10
+ from diffusers.utils.loading_utils import load_video
11
+ import torch
12
+ from torchvision.utils import make_grid
13
+
14
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
15
+
16
+ def render(
17
+ scene: pyrender.Scene,
18
+ renderer: pyrender.Renderer,
19
+ camera: pyrender.Camera,
20
+ pose: np.ndarray,
21
+ light: Optional[pyrender.Light] = None,
22
+ normalize_depth: bool = False,
23
+ flags: int = pyrender.constants.RenderFlags.NONE,
24
+ return_type: Literal['pil', 'ndarray'] = 'pil'
25
+ ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[Image.Image, Image.Image]]:
26
+ camera_node = scene.add(camera, pose=pose)
27
+ if light is not None:
28
+ light_node = scene.add(light, pose=pose)
29
+ image, depth = renderer.render(
30
+ scene,
31
+ flags=flags
32
+ )
33
+ scene.remove_node(camera_node)
34
+ if light is not None:
35
+ scene.remove_node(light_node)
36
+ if normalize_depth or return_type == 'pil':
37
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
38
+ if return_type == 'pil':
39
+ image = Image.fromarray(image)
40
+ depth = Image.fromarray(depth.astype(np.uint8))
41
+ return image, depth
42
+
43
+ def rotation_matrix_from_vectors(vec1, vec2):
44
+ a, b = vec1 / np.linalg.norm(vec1), vec2 / np.linalg.norm(vec2)
45
+ v = np.cross(a, b)
46
+ c = np.dot(a, b)
47
+ s = np.linalg.norm(v)
48
+ if s == 0:
49
+ return np.eye(3) if c > 0 else -np.eye(3)
50
+ kmat = np.array([
51
+ [0, -v[2], v[1]],
52
+ [v[2], 0, -v[0]],
53
+ [-v[1], v[0], 0]
54
+ ])
55
+ return np.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
56
+
57
+ def create_circular_camera_positions(
58
+ num_views: int,
59
+ radius: float,
60
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0])
61
+ ) -> List[np.ndarray]:
62
+ # Create a list of positions for a circular camera trajectory
63
+ # around the given axis with the given radius.
64
+ positions = []
65
+ axis = axis / np.linalg.norm(axis)
66
+ for i in range(num_views):
67
+ theta = 2 * np.pi * i / num_views
68
+ position = np.array([
69
+ np.sin(theta) * radius,
70
+ 0.0,
71
+ np.cos(theta) * radius
72
+ ])
73
+ if not np.allclose(axis, np.array([0.0, 1.0, 0.0])):
74
+ R = rotation_matrix_from_vectors(np.array([0.0, 1.0, 0.0]), axis)
75
+ position = R @ position
76
+ positions.append(position)
77
+ return positions
78
+
79
+ def create_circular_camera_poses(
80
+ num_views: int,
81
+ radius: float,
82
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0])
83
+ ) -> List[np.ndarray]:
84
+ # Create a list of poses for a circular camera trajectory
85
+ # around the given axis with the given radius.
86
+ # The camera always looks at the origin.
87
+ # The up vector is always [0, 1, 0].
88
+ canonical_pose = np.array([
89
+ [1.0, 0.0, 0.0, 0.0],
90
+ [0.0, 1.0, 0.0, 0.0],
91
+ [0.0, 0.0, 1.0, radius],
92
+ [0.0, 0.0, 0.0, 1.0]
93
+ ])
94
+ poses = []
95
+ for i in range(num_views):
96
+ theta = 2 * np.pi * i / num_views
97
+ R = rotation_matrix(
98
+ angle=theta,
99
+ direction=axis,
100
+ point=[0, 0, 0]
101
+ )
102
+ pose = R @ canonical_pose
103
+ poses.append(pose)
104
+ return poses
105
+
106
+ def render_views_around_mesh(
107
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
108
+ num_views: int = 36,
109
+ radius: float = 3.5,
110
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
111
+ image_size: tuple = (512, 512),
112
+ fov: float = 40.0,
113
+ light_intensity: Optional[float] = 5.0,
114
+ znear: float = 0.1,
115
+ zfar: float = 10.0,
116
+ normalize_depth: bool = False,
117
+ flags: int = pyrender.constants.RenderFlags.NONE,
118
+ return_depth: bool = False,
119
+ return_type: Literal['pil', 'ndarray'] = 'pil'
120
+ ) -> Union[
121
+ List[Image.Image],
122
+ List[np.ndarray],
123
+ Tuple[List[Image.Image], List[Image.Image]],
124
+ Tuple[List[np.ndarray], List[np.ndarray]]
125
+ ]:
126
+
127
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
128
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
129
+ if isinstance(mesh, trimesh.Trimesh):
130
+ mesh = trimesh.Scene(mesh)
131
+
132
+ scene = pyrender.Scene.from_trimesh_scene(mesh)
133
+ light = pyrender.DirectionalLight(
134
+ color=np.ones(3),
135
+ intensity=light_intensity
136
+ ) if light_intensity is not None else None
137
+ camera = pyrender.PerspectiveCamera(
138
+ yfov=np.deg2rad(fov),
139
+ aspectRatio=image_size[0]/image_size[1],
140
+ znear=znear,
141
+ zfar=zfar
142
+ )
143
+ renderer = pyrender.OffscreenRenderer(*image_size)
144
+
145
+ camera_poses = create_circular_camera_poses(
146
+ num_views,
147
+ radius,
148
+ axis = axis
149
+ )
150
+
151
+ images, depths = [], []
152
+ for pose in camera_poses:
153
+ image, depth = render(
154
+ scene, renderer, camera, pose, light,
155
+ normalize_depth=normalize_depth,
156
+ flags=flags,
157
+ return_type=return_type
158
+ )
159
+ images.append(image)
160
+ depths.append(depth)
161
+
162
+ renderer.delete()
163
+
164
+ if return_depth:
165
+ return images, depths
166
+ return images
167
+
168
+ def render_normal_views_around_mesh(
169
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
170
+ num_views: int = 36,
171
+ radius: float = 3.5,
172
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
173
+ image_size: tuple = (512, 512),
174
+ fov: float = 40.0,
175
+ light_intensity: Optional[float] = 5.0,
176
+ znear: float = 0.1,
177
+ zfar: float = 10.0,
178
+ normalize_depth: bool = False,
179
+ flags: int = pyrender.constants.RenderFlags.NONE,
180
+ return_depth: bool = False,
181
+ return_type: Literal['pil', 'ndarray'] = 'pil'
182
+ ) -> Union[
183
+ List[Image.Image],
184
+ List[np.ndarray],
185
+ Tuple[List[Image.Image], List[Image.Image]],
186
+ Tuple[List[np.ndarray], List[np.ndarray]]
187
+ ]:
188
+
189
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
190
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
191
+ if isinstance(mesh, trimesh.Scene):
192
+ mesh = mesh.to_geometry()
193
+ normals = mesh.vertex_normals
194
+ colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
195
+ mesh.visual = trimesh.visual.ColorVisuals(
196
+ mesh=mesh,
197
+ vertex_colors=colors
198
+ )
199
+ mesh = trimesh.Scene(mesh)
200
+ return render_views_around_mesh(
201
+ mesh, num_views, radius, axis,
202
+ image_size, fov, light_intensity, znear, zfar,
203
+ normalize_depth, flags,
204
+ return_depth, return_type
205
+ )
206
+
207
+ def create_camera_pose_on_sphere(
208
+ azimuth: float = 0.0, # in degrees
209
+ elevation: float = 0.0, # in degrees
210
+ radius: float = 3.5,
211
+ ) -> np.ndarray:
212
+ # Create a camera pose for a given azimuth and elevation
213
+ # with the given radius.
214
+ # The camera always looks at the origin.
215
+ # The up vector is always [0, 1, 0].
216
+ canonical_pose = np.array([
217
+ [1.0, 0.0, 0.0, 0.0],
218
+ [0.0, 1.0, 0.0, 0.0],
219
+ [0.0, 0.0, 1.0, radius],
220
+ [0.0, 0.0, 0.0, 1.0]
221
+ ])
222
+ azimuth = np.deg2rad(azimuth)
223
+ elevation = np.deg2rad(elevation)
224
+ position = np.array([
225
+ np.cos(elevation) * np.sin(azimuth),
226
+ np.sin(elevation),
227
+ np.cos(elevation) * np.cos(azimuth),
228
+ ])
229
+ R = np.eye(4)
230
+ R[:3, :3] = rotation_matrix_from_vectors(
231
+ np.array([0.0, 0.0, 1.0]),
232
+ position
233
+ )
234
+ pose = R @ canonical_pose
235
+ return pose
236
+
237
+ def render_single_view(
238
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
239
+ azimuth: float = 0.0, # in degrees
240
+ elevation: float = 0.0, # in degrees
241
+ radius: float = 3.5,
242
+ image_size: tuple = (512, 512),
243
+ fov: float = 40.0,
244
+ light_intensity: Optional[float] = 5.0,
245
+ num_env_lights: int = 0,
246
+ znear: float = 0.1,
247
+ zfar: float = 10.0,
248
+ normalize_depth: bool = False,
249
+ flags: int = pyrender.constants.RenderFlags.NONE,
250
+ return_depth: bool = False,
251
+ return_type: Literal['pil', 'ndarray'] = 'pil'
252
+ ) -> Union[
253
+ Image.Image,
254
+ np.ndarray,
255
+ Tuple[Image.Image, Image.Image],
256
+ Tuple[np.ndarray, np.ndarray]
257
+ ]:
258
+
259
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
260
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
261
+ if isinstance(mesh, trimesh.Trimesh):
262
+ mesh = trimesh.Scene(mesh)
263
+
264
+ scene = pyrender.Scene.from_trimesh_scene(mesh)
265
+ light = pyrender.DirectionalLight(
266
+ color=np.ones(3),
267
+ intensity=light_intensity
268
+ ) if light_intensity is not None else None
269
+ camera = pyrender.PerspectiveCamera(
270
+ yfov=np.deg2rad(fov),
271
+ aspectRatio=image_size[0]/image_size[1],
272
+ znear=znear,
273
+ zfar=zfar
274
+ )
275
+ renderer = pyrender.OffscreenRenderer(*image_size)
276
+
277
+ camera_pose = create_camera_pose_on_sphere(
278
+ azimuth,
279
+ elevation,
280
+ radius
281
+ )
282
+
283
+ if num_env_lights > 0:
284
+ env_light_poses = create_circular_camera_poses(
285
+ num_env_lights,
286
+ radius,
287
+ axis = np.array([0.0, 1.0, 0.0])
288
+ )
289
+ for pose in env_light_poses:
290
+ scene.add(pyrender.DirectionalLight(
291
+ color=np.ones(3),
292
+ intensity=light_intensity
293
+ ), pose=pose)
294
+ # set light to None
295
+ light = None
296
+
297
+ image, depth = render(
298
+ scene, renderer, camera, camera_pose, light,
299
+ normalize_depth=normalize_depth,
300
+ flags=flags,
301
+ return_type=return_type
302
+ )
303
+ renderer.delete()
304
+
305
+ if return_depth:
306
+ return image, depth
307
+ return image
308
+
309
+ def render_normal_single_view(
310
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
311
+ azimuth: float = 0.0, # in degrees
312
+ elevation: float = 0.0, # in degrees
313
+ radius: float = 3.5,
314
+ image_size: tuple = (512, 512),
315
+ fov: float = 40.0,
316
+ light_intensity: Optional[float] = 5.0,
317
+ znear: float = 0.1,
318
+ zfar: float = 10.0,
319
+ normalize_depth: bool = False,
320
+ flags: int = pyrender.constants.RenderFlags.NONE,
321
+ return_depth: bool = False,
322
+ return_type: Literal['pil', 'ndarray'] = 'pil'
323
+ ) -> Union[
324
+ Image.Image,
325
+ np.ndarray,
326
+ Tuple[Image.Image, Image.Image],
327
+ Tuple[np.ndarray, np.ndarray]
328
+ ]:
329
+
330
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
331
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
332
+ if isinstance(mesh, trimesh.Scene):
333
+ mesh = mesh.to_geometry()
334
+ normals = mesh.vertex_normals
335
+ colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
336
+ mesh.visual = trimesh.visual.ColorVisuals(
337
+ mesh=mesh,
338
+ vertex_colors=colors
339
+ )
340
+ mesh = trimesh.Scene(mesh)
341
+ return render_single_view(
342
+ mesh, azimuth, elevation, radius,
343
+ image_size, fov, light_intensity, znear, zfar,
344
+ normalize_depth, flags,
345
+ return_depth, return_type
346
+ )
347
+
348
+ def export_renderings(
349
+ images: List[Image.Image],
350
+ export_path: str,
351
+ fps: int = 36,
352
+ loop: int = 0
353
+ ):
354
+ export_type = export_path.split('.')[-1]
355
+ if export_type == 'mp4':
356
+ export_to_video(
357
+ images,
358
+ export_path,
359
+ fps=fps,
360
+ )
361
+ elif export_type == 'gif':
362
+ duration = 1000 / fps
363
+ images[0].save(
364
+ export_path,
365
+ save_all=True,
366
+ append_images=images[1:],
367
+ duration=duration,
368
+ loop=loop
369
+ )
370
+ else:
371
+ raise ValueError(f'Unknown export type: {export_type}')
372
+
373
+ def make_grid_for_images_or_videos(
374
+ images_or_videos: Union[List[Image.Image], List[List[Image.Image]]],
375
+ nrow: int = 4,
376
+ padding: int = 0,
377
+ pad_value: int = 0,
378
+ image_size: tuple = (512, 512),
379
+ return_type: Literal['pil', 'ndarray'] = 'pil'
380
+ ) -> Union[Image.Image, List[Image.Image], np.ndarray]:
381
+ if isinstance(images_or_videos[0], Image.Image):
382
+ images = [np.array(image.resize(image_size).convert('RGB')) for image in images_or_videos]
383
+ images = np.stack(images, axis=0).transpose(0, 3, 1, 2) # [N, C, H, W]
384
+ images = torch.from_numpy(images)
385
+ image_grid = make_grid(
386
+ images,
387
+ nrow=nrow,
388
+ padding=padding,
389
+ pad_value=pad_value,
390
+ normalize=False
391
+ ) # [C, H', W']
392
+ image_grid = image_grid.cpu().numpy()
393
+ if return_type == 'pil':
394
+ image_grid = Image.fromarray(image_grid.transpose(1, 2, 0))
395
+ return image_grid
396
+ elif isinstance(images_or_videos[0], list) and isinstance(images_or_videos[0][0], Image.Image):
397
+ image_grids = []
398
+ for i in range(len(images_or_videos[0])):
399
+ images = [video[i] for video in images_or_videos]
400
+ image_grid = make_grid_for_images_or_videos(
401
+ images,
402
+ nrow=nrow,
403
+ padding=padding,
404
+ return_type=return_type
405
+ )
406
+ image_grids.append(image_grid)
407
+ if return_type == 'ndarray':
408
+ image_grids = np.stack(image_grids, axis=0)
409
+ return image_grids
410
+ else:
411
+ raise ValueError(f'Unknown input type: {type(images_or_videos[0])}')
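
The rendering helpers build an offscreen pyrender scene (EGL), orbit a perspective camera around the object and return PIL frames that `export_renderings` can write as an MP4 or GIF. A minimal sketch follows; it is not part of this commit, the primitive mesh and file names are placeholders, and a headless EGL-capable GPU setup is assumed.

```python
# Illustrative only: turntable rendering with the helpers defined above.
import numpy as np
import trimesh

mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.8)

frames = render_views_around_mesh(
    mesh,
    num_views=36,                       # one frame every 10 degrees
    radius=3.5,
    axis=np.array([0.0, 1.0, 0.0]),     # orbit around the vertical axis
    image_size=(512, 512),
)
export_renderings(frames, "turntable.gif", fps=18)

# A single calibrated view, e.g. for figures:
front = render_single_view(mesh, azimuth=30.0, elevation=20.0, radius=3.5)
front.save("front_view.png")
```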
src/utils/smoothing.py ADDED
@@ -0,0 +1,643 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2012-2015, P. M. Neila
4
+ # All rights reserved.
5
+
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+
9
+ # * Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+
12
+ # * Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+
16
+ # * Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ """
32
+ Utilities for smoothing the occ/sdf grids.
33
+ """
34
+
35
+ import logging
36
+ from typing import Tuple
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torch.nn.functional as F
41
+ from scipy import ndimage as ndi
42
+ from scipy import sparse
43
+
44
+ __all__ = [
45
+ "smooth",
46
+ "smooth_constrained",
47
+ "smooth_gaussian",
48
+ "signed_distance_function",
49
+ "smooth_gpu",
50
+ "smooth_constrained_gpu",
51
+ "smooth_gaussian_gpu",
52
+ "signed_distance_function_gpu",
53
+ ]
54
+
55
+
56
+ def _build_variable_indices(band: np.ndarray) -> np.ndarray:
57
+ num_variables = np.count_nonzero(band)
58
+ variable_indices = np.full(band.shape, -1, dtype=np.int_)
59
+ variable_indices[band] = np.arange(num_variables)
60
+ return variable_indices
61
+
62
+
63
+ def _buildq3d(variable_indices: np.ndarray):
64
+ """
65
+ Builds the filterq matrix for the given variables.
66
+ """
67
+
68
+ num_variables = variable_indices.max() + 1
69
+ filterq = sparse.lil_matrix((3 * num_variables, num_variables))
70
+
71
+ # Pad variable_indices to simplify out-of-bounds accesses
72
+ variable_indices = np.pad(
73
+ variable_indices, [(0, 1), (0, 1), (0, 1)], mode="constant", constant_values=-1
74
+ )
75
+
76
+ coords = np.nonzero(variable_indices >= 0)
77
+ for count, (i, j, k) in enumerate(zip(*coords)):
78
+
79
+ assert variable_indices[i, j, k] == count
80
+
81
+ filterq[3 * count, count] = -2
82
+ neighbor = variable_indices[i - 1, j, k]
83
+ if neighbor >= 0:
84
+ filterq[3 * count, neighbor] = 1
85
+ else:
86
+ filterq[3 * count, count] += 1
87
+
88
+ neighbor = variable_indices[i + 1, j, k]
89
+ if neighbor >= 0:
90
+ filterq[3 * count, neighbor] = 1
91
+ else:
92
+ filterq[3 * count, count] += 1
93
+
94
+ filterq[3 * count + 1, count] = -2
95
+ neighbor = variable_indices[i, j - 1, k]
96
+ if neighbor >= 0:
97
+ filterq[3 * count + 1, neighbor] = 1
98
+ else:
99
+ filterq[3 * count + 1, count] += 1
100
+
101
+ neighbor = variable_indices[i, j + 1, k]
102
+ if neighbor >= 0:
103
+ filterq[3 * count + 1, neighbor] = 1
104
+ else:
105
+ filterq[3 * count + 1, count] += 1
106
+
107
+ filterq[3 * count + 2, count] = -2
108
+ neighbor = variable_indices[i, j, k - 1]
109
+ if neighbor >= 0:
110
+ filterq[3 * count + 2, neighbor] = 1
111
+ else:
112
+ filterq[3 * count + 2, count] += 1
113
+
114
+ neighbor = variable_indices[i, j, k + 1]
115
+ if neighbor >= 0:
116
+ filterq[3 * count + 2, neighbor] = 1
117
+ else:
118
+ filterq[3 * count + 2, count] += 1
119
+
120
+ filterq = filterq.tocsr()
121
+ return filterq.T.dot(filterq)
122
+
123
+
124
+ def _buildq3d_gpu(variable_indices: torch.Tensor, chunk_size=10000):
125
+ """
126
+ Builds the filterq matrix for the given variables on GPU, using chunking to reduce memory usage.
127
+ """
128
+ device = variable_indices.device
129
+ num_variables = variable_indices.max().item() + 1
130
+
131
+ # Pad variable_indices to simplify out-of-bounds accesses
132
+ variable_indices = torch.nn.functional.pad(
133
+ variable_indices, (0, 1, 0, 1, 0, 1), mode="constant", value=-1
134
+ )
135
+
136
+ coords = torch.nonzero(variable_indices >= 0)
137
+ i, j, k = coords[:, 0], coords[:, 1], coords[:, 2]
138
+
139
+ # Function to process a chunk of data
140
+ def process_chunk(start, end):
141
+ row_indices = []
142
+ col_indices = []
143
+ values = []
144
+
145
+ for axis in range(3):
146
+ row_indices.append(3 * torch.arange(start, end, device=device) + axis)
147
+ col_indices.append(
148
+ variable_indices[i[start:end], j[start:end], k[start:end]]
149
+ )
150
+ values.append(torch.full((end - start,), -2, device=device))
151
+
152
+ for offset in [-1, 1]:
153
+ if axis == 0:
154
+ neighbor = variable_indices[
155
+ i[start:end] + offset, j[start:end], k[start:end]
156
+ ]
157
+ elif axis == 1:
158
+ neighbor = variable_indices[
159
+ i[start:end], j[start:end] + offset, k[start:end]
160
+ ]
161
+ else:
162
+ neighbor = variable_indices[
163
+ i[start:end], j[start:end], k[start:end] + offset
164
+ ]
165
+
166
+ mask = neighbor >= 0
167
+ row_indices.append(
168
+ 3 * torch.arange(start, end, device=device)[mask] + axis
169
+ )
170
+ col_indices.append(neighbor[mask])
171
+ values.append(torch.ones(mask.sum(), device=device))
172
+
173
+ # Add 1 to the diagonal for out-of-bounds neighbors
174
+ row_indices.append(
175
+ 3 * torch.arange(start, end, device=device)[~mask] + axis
176
+ )
177
+ col_indices.append(
178
+ variable_indices[i[start:end], j[start:end], k[start:end]][~mask]
179
+ )
180
+ values.append(torch.ones((~mask).sum(), device=device))
181
+
182
+ return torch.cat(row_indices), torch.cat(col_indices), torch.cat(values)
183
+
184
+ # Process data in chunks
185
+ all_row_indices = []
186
+ all_col_indices = []
187
+ all_values = []
188
+
189
+ for start in range(0, coords.shape[0], chunk_size):
190
+ end = min(start + chunk_size, coords.shape[0])
191
+ row_indices, col_indices, values = process_chunk(start, end)
192
+ all_row_indices.append(row_indices)
193
+ all_col_indices.append(col_indices)
194
+ all_values.append(values)
195
+
196
+ # Concatenate all chunks
197
+ row_indices = torch.cat(all_row_indices)
198
+ col_indices = torch.cat(all_col_indices)
199
+ values = torch.cat(all_values)
200
+
201
+ # Create sparse tensor
202
+ indices = torch.stack([row_indices, col_indices])
203
+ filterq = torch.sparse_coo_tensor(
204
+ indices, values, (3 * num_variables, num_variables)
205
+ )
206
+
207
+ # Compute filterq.T @ filterq
208
+ return torch.sparse.mm(filterq.t(), filterq)
209
+
210
+
211
+ # Usage example:
212
+ # variable_indices = torch.tensor(...).cuda() # Your input tensor on GPU
213
+ # result = _buildq3d_gpu(variable_indices)
214
+
215
+
216
+ def _buildq2d(variable_indices: np.ndarray):
217
+ """
218
+ Builds the filterq matrix for the given variables.
219
+
220
+ Version for 2 dimensions.
221
+ """
222
+
223
+ num_variables = variable_indices.max() + 1
224
+ filterq = sparse.lil_matrix((3 * num_variables, num_variables))
225
+
226
+ # Pad variable_indices to simplify out-of-bounds accesses
227
+ variable_indices = np.pad(
228
+ variable_indices, [(0, 1), (0, 1)], mode="constant", constant_values=-1
229
+ )
230
+
231
+ coords = np.nonzero(variable_indices >= 0)
232
+ for count, (i, j) in enumerate(zip(*coords)):
233
+ assert variable_indices[i, j] == count
234
+
235
+ filterq[2 * count, count] = -2
236
+ neighbor = variable_indices[i - 1, j]
237
+ if neighbor >= 0:
238
+ filterq[2 * count, neighbor] = 1
239
+ else:
240
+ filterq[2 * count, count] += 1
241
+
242
+ neighbor = variable_indices[i + 1, j]
243
+ if neighbor >= 0:
244
+ filterq[2 * count, neighbor] = 1
245
+ else:
246
+ filterq[2 * count, count] += 1
247
+
248
+ filterq[2 * count + 1, count] = -2
249
+ neighbor = variable_indices[i, j - 1]
250
+ if neighbor >= 0:
251
+ filterq[2 * count + 1, neighbor] = 1
252
+ else:
253
+ filterq[2 * count + 1, count] += 1
254
+
255
+ neighbor = variable_indices[i, j + 1]
256
+ if neighbor >= 0:
257
+ filterq[2 * count + 1, neighbor] = 1
258
+ else:
259
+ filterq[2 * count + 1, count] += 1
260
+
261
+ filterq = filterq.tocsr()
262
+ return filterq.T.dot(filterq)
263
+
264
+
265
+ def _jacobi(
266
+ filterq,
267
+ x0: np.ndarray,
268
+ lower_bound: np.ndarray,
269
+ upper_bound: np.ndarray,
270
+ max_iters: int = 10,
271
+ rel_tol: float = 1e-6,
272
+ weight: float = 0.5,
273
+ ):
274
+ """Jacobi method with constraints."""
275
+
276
+ jacobi_r = sparse.lil_matrix(filterq)
277
+ shp = jacobi_r.shape
278
+ jacobi_d = 1.0 / filterq.diagonal()
279
+ jacobi_r.setdiag((0,) * shp[0])
280
+ jacobi_r = jacobi_r.tocsr()
281
+
282
+ x = x0
283
+
284
+ # We check the stopping criterion each 10 iterations
285
+ check_each = 10
286
+ cum_rel_tol = 1 - (1 - rel_tol) ** check_each
287
+
288
+ energy_now = np.dot(x, filterq.dot(x)) / 2
289
+ logging.info("Energy at iter %d: %.6g", 0, energy_now)
290
+ for i in range(max_iters):
291
+
292
+ x_1 = -jacobi_d * jacobi_r.dot(x)
293
+ x = weight * x_1 + (1 - weight) * x
294
+
295
+ # Constraints.
296
+ x = np.maximum(x, lower_bound)
297
+ x = np.minimum(x, upper_bound)
298
+
299
+ # Stopping criterion
300
+ if (i + 1) % check_each == 0:
301
+ # Update energy
302
+ energy_before = energy_now
303
+ energy_now = np.dot(x, filterq.dot(x)) / 2
304
+
305
+ logging.info("Energy at iter %d: %.6g", i + 1, energy_now)
306
+
307
+ # Check stopping criterion
308
+ cum_rel_improvement = (energy_before - energy_now) / energy_before
309
+ if cum_rel_improvement < cum_rel_tol:
310
+ break
311
+
312
+ return x
313
+
314
+
315
+ def signed_distance_function(
316
+ levelset: np.ndarray, band_radius: int
317
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
318
+ """
319
+ Return the distance to the 0.5 levelset of a function, the mask of the
320
+ border (i.e., the nearest cells to the 0.5 level-set) and the mask of the
321
+ band (i.e., the cells of the function whose distance to the 0.5 level-set
322
+ is less of equal to `band_radius`).
323
+ """
324
+
325
+ binary_array = np.where(levelset > 0, True, False)
326
+
327
+ # Compute the band and the border.
328
+ dist_func = ndi.distance_transform_edt
329
+ distance = np.where(
330
+ binary_array, dist_func(binary_array) - 0.5, -dist_func(~binary_array) + 0.5
331
+ )
332
+ border = np.abs(distance) < 1
333
+ band = np.abs(distance) <= band_radius
334
+
335
+ return distance, border, band
336
+
337
+
338
+ def signed_distance_function_iso0(
339
+ levelset: np.ndarray, band_radius: int
340
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
341
+ """
342
+ Return the distance to the 0 levelset of a function, the mask of the
343
+ border (i.e., the nearest cells to the 0 level-set) and the mask of the
344
+ band (i.e., the cells of the function whose distance to the 0 level-set
345
+ is less of equal to `band_radius`).
346
+ """
347
+
348
+ binary_array = levelset > 0
349
+
350
+ # Compute the band and the border.
351
+ dist_func = ndi.distance_transform_edt
352
+ distance = np.where(
353
+ binary_array, dist_func(binary_array), -dist_func(~binary_array)
354
+ )
355
+ border = np.zeros_like(levelset, dtype=bool)
356
+ border[:-1, :, :] |= levelset[:-1, :, :] * levelset[1:, :, :] <= 0
357
+ border[:, :-1, :] |= levelset[:, :-1, :] * levelset[:, 1:, :] <= 0
358
+ border[:, :, :-1] |= levelset[:, :, :-1] * levelset[:, :, 1:] <= 0
359
+ band = np.abs(distance) <= band_radius
360
+
361
+ return distance, border, band
362
+
363
+
364
+ def signed_distance_function_gpu(levelset: torch.Tensor, band_radius: int):
365
+ binary_array = (levelset > 0).float()
366
+
367
+ # Compute distance transform
368
+ dist_pos = (
369
+ F.max_pool3d(
370
+ -binary_array.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
371
+ )
372
+ .squeeze(0)
373
+ .squeeze(0)
374
+ + binary_array
375
+ )
376
+ dist_neg = F.max_pool3d(
377
+ (binary_array - 1).unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
378
+ ).squeeze(0).squeeze(0) + (1 - binary_array)
379
+
380
+ distance = torch.where(binary_array > 0, dist_pos - 0.5, -dist_neg + 0.5)
381
+
382
+ # breakpoint()
383
+
384
+ # Use levelset as distance directly
385
+ # distance = levelset
386
+ # print(distance.shape)
387
+ # Compute border and band
388
+ border = torch.abs(distance) < 1
389
+ band = torch.abs(distance) <= band_radius
390
+
391
+ return distance, border, band
392
+
393
+
394
+ def smooth_constrained(
395
+ binary_array: np.ndarray,
396
+ band_radius: int = 4,
397
+ max_iters: int = 250,
398
+ rel_tol: float = 1e-6,
399
+ ) -> np.ndarray:
400
+ """
401
+ Implementation of the smoothing method from
402
+
403
+ "Surface Extraction from Binary Volumes with Higher-Order Smoothness"
404
+ Victor Lempitsky, CVPR10
405
+ """
406
+
407
+ # # Compute the distance map, the border and the band.
408
+ logging.info("Computing distance transform...")
409
+ # distance, _, band = signed_distance_function(binary_array, band_radius)
410
+ binary_array_gpu = torch.from_numpy(binary_array).cuda()
411
+ distance, _, band = signed_distance_function_gpu(binary_array_gpu, band_radius)
412
+ distance = distance.cpu().numpy()
413
+ band = band.cpu().numpy()
414
+
415
+ variable_indices = _build_variable_indices(band)
416
+
417
+ # Compute filterq.
418
+ logging.info("Building matrix filterq...")
419
+ if binary_array.ndim == 3:
420
+ filterq = _buildq3d(variable_indices)
421
+ # variable_indices_gpu = torch.from_numpy(variable_indices).cuda()
422
+ # filterq_gpu = _buildq3d_gpu(variable_indices_gpu)
423
+ # filterq = filterq_gpu.cpu().numpy()
424
+ elif binary_array.ndim == 2:
425
+ filterq = _buildq2d(variable_indices)
426
+ else:
427
+ raise ValueError("binary_array.ndim not in [2, 3]")
428
+
429
+ # Initialize the variables.
430
+ res = np.asarray(distance, dtype=np.double)
431
+ x = res[band]
432
+ upper_bound = np.where(x < 0, x, np.inf)
433
+ lower_bound = np.where(x > 0, x, -np.inf)
434
+
435
+ upper_bound[np.abs(upper_bound) < 1] = 0
436
+ lower_bound[np.abs(lower_bound) < 1] = 0
437
+
438
+ # Solve.
439
+ logging.info("Minimizing energy...")
440
+ x = _jacobi(
441
+ filterq=filterq,
442
+ x0=x,
443
+ lower_bound=lower_bound,
444
+ upper_bound=upper_bound,
445
+ max_iters=max_iters,
446
+ rel_tol=rel_tol,
447
+ )
448
+
449
+ res[band] = x
450
+ return res
451
+
452
+
453
+ def total_variation_denoising(x, weight=0.1, num_iterations=5, eps=1e-8):
454
+ diff_x = torch.diff(x, dim=0, prepend=x[:1])
455
+ diff_y = torch.diff(x, dim=1, prepend=x[:, :1])
456
+ diff_z = torch.diff(x, dim=2, prepend=x[:, :, :1])
457
+
458
+ norm = torch.sqrt(diff_x**2 + diff_y**2 + diff_z**2 + eps)
459
+
460
+ div_x = torch.diff(diff_x / norm, dim=0, append=diff_x[-1:] / norm[-1:])
461
+ div_y = torch.diff(diff_y / norm, dim=1, append=diff_y[:, -1:] / norm[:, -1:])
462
+ div_z = torch.diff(diff_z / norm, dim=2, append=diff_z[:, :, -1:] / norm[:, :, -1:])
463
+
464
+ return x - weight * (div_x + div_y + div_z)
465
+
466
+
467
+ def smooth_constrained_gpu(
468
+ binary_array: torch.Tensor,
469
+ band_radius: int = 4,
470
+ max_iters: int = 250,
471
+ rel_tol: float = 1e-4,
472
+ ):
473
+ distance, _, band = signed_distance_function_gpu(binary_array, band_radius)
474
+
475
+ # Initialize variables
476
+ x = distance[band]
477
+ upper_bound = torch.where(x < 0, x, torch.tensor(float("inf"), device=x.device))
478
+ lower_bound = torch.where(x > 0, x, torch.tensor(float("-inf"), device=x.device))
479
+
480
+ upper_bound[torch.abs(upper_bound) < 1] = 0
481
+ lower_bound[torch.abs(lower_bound) < 1] = 0
482
+
483
+ # Define the 3D Laplacian kernel
484
+ laplacian_kernel = torch.tensor(
485
+ [
486
+ [
487
+ [
488
+ [[0, 1, 0], [1, -6, 1], [0, 1, 0]],
489
+ [[1, 0, 1], [0, 0, 0], [1, 0, 1]],
490
+ [[0, 1, 0], [1, 0, 1], [0, 1, 0]],
491
+ ]
492
+ ]
493
+ ],
494
+ device=x.device,
495
+ ).float()
496
+
497
+ laplacian_kernel = laplacian_kernel / laplacian_kernel.abs().sum()
498
+
499
+ # breakpoint()  # debugging hook; disabled so non-interactive runs do not stop here
500
+
501
+ # Simplified Jacobi iteration
502
+ for i in range(max_iters):
503
+ # Reshape x to 5D tensor (batch, channel, depth, height, width)
504
+ x_5d = x.view(1, 1, *band.shape)
505
+ x_3d = x.view(*band.shape)
506
+
507
+ # Apply 3D convolution
508
+ laplacian = F.conv3d(x_5d, laplacian_kernel, padding=1)
509
+
510
+ # Reshape back to original dimensions
511
+ laplacian = laplacian.view(x.shape)
512
+
513
+ # Use a small relaxation factor to improve stability
514
+ relaxation_factor = 0.1
515
+ tv_weight = 0.1
516
+ # x_new = x + relaxation_factor * laplacian
517
+ x_new = total_variation_denoising(x_3d, weight=tv_weight)
518
+ # Print laplacian min and max
519
+ # print(f"Laplacian min: {laplacian.min().item():.4f}, max: {laplacian.max().item():.4f}")
520
+
521
+ # Apply constraints
522
+ # Reshape x_new to match the dimensions of lower_bound and upper_bound
523
+ x_new = x_new.view(x.shape)
524
+ x_new = torch.clamp(x_new, min=lower_bound, max=upper_bound)
525
+
526
+ # Check for convergence
527
+ diff_norm = torch.norm(x_new - x)
528
+ print(diff_norm)
529
+ x_norm = torch.norm(x)
530
+
531
+ if x_norm > 1e-8: # Avoid division by very small numbers
532
+ relative_change = diff_norm / x_norm
533
+ if relative_change < rel_tol:
534
+ break
535
+ elif diff_norm < rel_tol: # If x_norm is very small, check absolute change
536
+ break
537
+
538
+ x = x_new
539
+
540
+ # Check for NaN and break if found, also check for inf
541
+ if torch.isnan(x).any() or torch.isinf(x).any():
542
+ print(f"NaN or Inf detected at iteration {i}")
543
+ # breakpoint()  # debugging hook; disabled so non-interactive runs do not stop here
544
+ break
545
+
546
+ result = distance.clone()
547
+ result[band] = x
548
+ return result
549
+
550
+
551
+ def smooth_gaussian(binary_array: np.ndarray, sigma: float = 3) -> np.ndarray:
552
+ vol = binary_array.astype(np.float64) - 0.5
553
+ return ndi.gaussian_filter(vol, sigma=sigma)
554
+
555
+
556
+ def smooth_gaussian_gpu(binary_array: torch.Tensor, sigma: float = 3):
557
+ # vol = binary_array.float()
558
+ vol = binary_array
559
+ kernel_size = int(2 * sigma + 1)
560
+ kernel = torch.ones(
561
+ 1,
562
+ 1,
563
+ kernel_size,
564
+ kernel_size,
565
+ kernel_size,
566
+ device=binary_array.device,
567
+ dtype=vol.dtype,
568
+ ) / (kernel_size**3)
569
+ return F.conv3d(
570
+ vol.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2
571
+ ).squeeze()
572
+
573
+
574
+ def smooth(binary_array: np.ndarray, method: str = "auto", **kwargs) -> np.ndarray:
575
+ """
576
+ Smooths the 0.5 level-set of a binary array. Returns a floating-point
577
+ array with a smoothed version of the original level-set in the 0 isovalue.
578
+
579
+ This function can apply two different methods:
580
+
581
+ - A constrained smoothing method which preserves details and fine
582
+ structures, but it is slow and requires a large amount of memory. This
583
+ method is recommended when the input array is small (smaller than
584
+ (500, 500, 500)).
585
+ - A Gaussian filter applied over the binary array. This method is fast, but
586
+ not very precise, as it can destroy fine details. It is only recommended
587
+ when the input array is large and the 0.5 level-set does not contain
588
+ thin structures.
589
+
590
+ Parameters
591
+ ----------
592
+ binary_array : ndarray
593
+ Input binary array with the 0.5 level-set to smooth.
594
+ method : str, one of ['auto', 'gaussian', 'constrained']
595
+ Smoothing method. If 'auto' is given, the method will be automatically
596
+ chosen based on the size of `binary_array`.
597
+
598
+ Parameters for 'gaussian'
599
+ -------------------------
600
+ sigma : float
601
+ Size of the Gaussian filter (default 3).
602
+
603
+ Parameters for 'constrained'
604
+ ----------------------------
605
+ max_iters : positive integer
606
+ Number of iterations of the constrained optimization method
607
+ (default 250).
608
+ rel_tol: float
609
+ Relative tolerance as a stopping criterion (default 1e-6).
610
+
611
+ Output
612
+ ------
613
+ res : ndarray
614
+ Floating-point array with a smoothed 0 level-set.
615
+ """
616
+
617
+ binary_array = np.asarray(binary_array)
618
+
619
+ if method == "auto":
620
+ if binary_array.size > 512**3:
621
+ method = "gaussian"
622
+ else:
623
+ method = "constrained"
624
+
625
+ if method == "gaussian":
626
+ return smooth_gaussian(binary_array, **kwargs)
627
+
628
+ if method == "constrained":
629
+ return smooth_constrained(binary_array, **kwargs)
630
+
631
+ raise ValueError("Unknown method '{}'".format(method))
632
+
633
+
634
+ def smooth_gpu(binary_array: torch.Tensor, method: str = "auto", **kwargs):
635
+ if method == "auto":
636
+ method = "gaussian" if binary_array.numel() > 512**3 else "constrained"
637
+
638
+ if method == "gaussian":
639
+ return smooth_gaussian_gpu(binary_array, **kwargs)
640
+ elif method == "constrained":
641
+ return smooth_constrained_gpu(binary_array, **kwargs)
642
+ else:
643
+ raise ValueError(f"Unknown method '{method}'")
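
`smooth` takes a binary occupancy volume and returns a floating-point grid whose 0 isosurface is a smoothed version of the original 0.5 level-set, via either the constrained method or a plain Gaussian filter. A short sketch on a toy sphere grid follows; it is not part of this commit and the grid size and sigma are arbitrary illustrative values.

```python
# Illustrative only: smooth a binary occupancy grid and mesh its 0 level-set.
import numpy as np
from skimage import measure

# Toy 96^3 occupancy grid of a sphere.
coords = np.mgrid[:96, :96, :96] - 48
binary = (coords ** 2).sum(axis=0) < 30 ** 2

smoothed = smooth(binary, method="gaussian", sigma=2)     # float grid, smoothed level-set at 0
verts, faces, _, _ = measure.marching_cubes(smoothed, level=0.0)
print(verts.shape, faces.shape)
```

The GPU variants (`smooth_gpu`, `smooth_constrained_gpu`) mirror the same interface on torch tensors.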
src/utils/train_utils.py ADDED
@@ -0,0 +1,184 @@
+ from src.utils.typing_utils import *
+
+ import os
+ from omegaconf import OmegaConf
+
+ from torch import optim
+ from torch.optim import lr_scheduler
+ from diffusers.training_utils import *
+ from diffusers.optimization import get_scheduler
+
+ # https://github.com/huggingface/diffusers/pull/9812: fix `self.use_ema_warmup`
+ class MyEMAModel(EMAModel):
+     """
+     Exponential Moving Average of model weights.
+     """
+
+     def __init__(
+         self,
+         parameters: Iterable[torch.nn.Parameter],
+         decay: float = 0.9999,
+         min_decay: float = 0.0,
+         update_after_step: int = 0,
+         use_ema_warmup: bool = False,
+         inv_gamma: Union[float, int] = 1.0,
+         power: Union[float, int] = 2 / 3,
+         foreach: bool = False,
+         model_cls: Optional[Any] = None,
+         model_config: Dict[str, Any] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+             decay (float): The decay factor for the exponential moving average.
+             min_decay (float): The minimum decay factor for the exponential moving average.
+             update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+             use_ema_warmup (bool): Whether to use EMA warmup.
+             inv_gamma (float):
+                 Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+             power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+             foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
+             device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
+                 weights will be stored on CPU.
+
+         @crowsonkb's notes on EMA Warmup:
+             If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+             to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+             gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+             at 215.4k steps).
+         """
+
+         if isinstance(parameters, torch.nn.Module):
+             deprecation_message = (
+                 "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
+                 "Please pass the parameters of the module instead."
+             )
+             deprecate(
+                 "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
+                 "1.0.0",
+                 deprecation_message,
+                 standard_warn=False,
+             )
+             parameters = parameters.parameters()
+
+             # # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
+             # use_ema_warmup = True
+
+         if kwargs.get("max_value", None) is not None:
+             deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
+             deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
+             decay = kwargs["max_value"]
+
+         if kwargs.get("min_value", None) is not None:
+             deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
+             deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
+             min_decay = kwargs["min_value"]
+
+         parameters = list(parameters)
+         self.shadow_params = [p.clone().detach() for p in parameters]
+
+         if kwargs.get("device", None) is not None:
+             deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
+             deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
+             self.to(device=kwargs["device"])
+
+         self.temp_stored_params = None
+
+         self.decay = decay
+         self.min_decay = min_decay
+         self.update_after_step = update_after_step
+         self.use_ema_warmup = use_ema_warmup
+         self.inv_gamma = inv_gamma
+         self.power = power
+         self.optimization_step = 0
+         self.cur_decay_value = None  # set in `step()`
+         self.foreach = foreach
+
+         self.model_cls = model_cls
+         self.model_config = model_config
+
+     def get_decay(self, optimization_step: int) -> float:
+         """
+         Compute the decay factor for the exponential moving average.
+         """
+         step = max(0, optimization_step - self.update_after_step - 1)
+
+         if step <= 0:
+             return 0.0
+
+         if self.use_ema_warmup:
+             cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+         else:
+             # cur_decay_value = (1 + step) / (10 + step)
+             cur_decay_value = self.decay
+
+         cur_decay_value = min(cur_decay_value, self.decay)
+         # make sure decay is not smaller than min_decay
+         cur_decay_value = max(cur_decay_value, self.min_decay)
+         return cur_decay_value
+
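As the linked diffusers PR comment above indicates, this subclass reproduces the upstream `EMAModel` constructor with the backwards-compatibility `use_ema_warmup = True` override left commented out, so the flag passed by the caller is the one that takes effect. With warmup enabled, `get_decay` follows `1 - (1 + step / inv_gamma) ** (-power)`, clamped to `[min_decay, decay]`. A small sketch of the resulting schedule; the toy module and step values are illustrative, not from the commit:

import torch

ema = MyEMAModel(torch.nn.Linear(4, 4).parameters(), decay=0.9999, use_ema_warmup=True)
for step in (1_000, 31_600, 1_000_000):
    print(step, round(ema.get_decay(step), 5))
# Roughly 0.99 at 1K steps, 0.999 at 31.6K steps, and capped at decay=0.9999 by 1M steps,
# matching @crowsonkb's note in the docstring (defaults inv_gamma=1, power=2/3).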
+ def get_configs(yaml_path: str, cli_configs: List[str] = [], **kwargs) -> DictConfig:
+     yaml_configs = OmegaConf.load(yaml_path)
+     cli_configs = OmegaConf.from_cli(cli_configs)
+
+     configs = OmegaConf.merge(yaml_configs, cli_configs, kwargs)
+     OmegaConf.resolve(configs)  # resolve ${...} placeholders
+     return configs
+
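`get_configs` layers three sources, with later ones overriding earlier ones: the YAML file, `key=value` overrides in OmegaConf dotlist form, and any extra keyword arguments. A hedged sketch; the file name, keys, and values are placeholders:

# configs/hypothetical.yaml contains, e.g.:
#   train:
#     lr: 1.0e-4
#     batch_size: 8
configs = get_configs(
    "configs/hypothetical.yaml",
    cli_configs=["train.lr=0.0002"],  # CLI dotlist overrides the YAML value
    seed=42,                          # extra kwargs are merged last
)
print(configs.train.lr, configs.seed)  # -> 0.0002 42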
+ def get_optimizer(name: str, params: Parameter, **kwargs) -> Optimizer:
+     if name == "adamw":
+         return optim.AdamW(params=params, **kwargs)
+     else:
+         raise NotImplementedError(f"Not implemented optimizer: {name}")
+
+ def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
+     if name == "one_cycle":
+         return lr_scheduler.OneCycleLR(
+             optimizer,
+             max_lr=kwargs["max_lr"],
+             total_steps=kwargs["total_steps"],
+             pct_start=kwargs["pct_start"],
+         )
+     elif name == "cosine_warmup":
+         return get_scheduler(
+             "cosine", optimizer,
+             num_warmup_steps=kwargs["num_warmup_steps"],
+             num_training_steps=kwargs["total_steps"],
+         )
+     elif name == "constant_warmup":
+         return get_scheduler(
+             "constant_with_warmup", optimizer,
+             num_warmup_steps=kwargs["num_warmup_steps"],
+             num_training_steps=kwargs["total_steps"],
+         )
+     elif name == "constant":
+         return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda _: 1)
+     elif name == "linear_decay":
+         return lr_scheduler.LambdaLR(
+             optimizer=optimizer,
+             lr_lambda=lambda epoch: max(0., 1. - epoch / kwargs["total_epochs"]),
+         )
+     else:
+         raise NotImplementedError(f"Not implemented lr scheduler: {name}")
+
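Both factories are meant to be driven by the experiment config: the name selects a branch, the remaining keyword arguments are passed straight through, and unsupported names fail loudly. A wiring sketch with illustrative hyperparameters; the linear layer stands in for the actual network:

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the actual network
optimizer = get_optimizer("adamw", model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = get_lr_scheduler(
    "cosine_warmup", optimizer,
    num_warmup_steps=1_000, total_steps=100_000,
)
for _ in range(3):  # one scheduler step per optimizer step
    optimizer.step()
    scheduler.step()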
+ def save_experiment_params(
+     args: Namespace,
+     configs: DictConfig,
+     save_dir: str
+ ) -> Dict[str, Any]:
+     params = OmegaConf.merge(configs, {"args": {k: str(v) for k, v in vars(args).items()}})
+     OmegaConf.save(params, os.path.join(save_dir, "params.yaml"))
+     return dict(params)
+
+
+ def save_model_architecture(model: Module, save_dir: str) -> None:
+     num_buffers = sum(b.numel() for b in model.buffers())
+     num_params = sum(p.numel() for p in model.parameters())
+     num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     message = f"Number of buffers: {num_buffers}\n" +\
+         f"Number of trainable / all parameters: {num_trainable_params} / {num_params}\n\n" +\
+         f"Model architecture:\n{model}"
+
+     with open(os.path.join(save_dir, "model.txt"), "w") as f:
+         f.write(message)
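Both helpers are run-directory bookkeeping: `save_experiment_params` snapshots the merged CLI arguments and config to `params.yaml`, and `save_model_architecture` writes parameter/buffer counts plus the module repr to `model.txt`. A sketch with placeholder paths and a toy module:

import argparse, os, torch
from omegaconf import OmegaConf

args = argparse.Namespace(config="configs/hypothetical.yaml", output_dir="outputs/demo")
configs = OmegaConf.create({"train": {"lr": 1e-4}})
os.makedirs(args.output_dir, exist_ok=True)

save_experiment_params(args, configs, args.output_dir)           # -> outputs/demo/params.yaml
save_model_architecture(torch.nn.Linear(4, 4), args.output_dir)  # -> outputs/demo/model.txt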
src/utils/typing_utils.py ADDED
@@ -0,0 +1,20 @@
+ from typing import *
+
+ from argparse import Namespace
+ from collections import defaultdict
+ from omegaconf import DictConfig, ListConfig
+ from omegaconf.base import ContainerMetadata, Metadata
+ from omegaconf.nodes import AnyNode
+
+ from torch import Tensor
+ from torch.nn import Parameter, Module
+ from torch.nn.parallel import DistributedDataParallel
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+ from torch.utils.data import DataLoader
+
+ from accelerate.optimizer import AcceleratedOptimizer
+ from accelerate.scheduler import AcceleratedScheduler
+ from accelerate.data_loader import DataLoaderShard
+
+
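`typing_utils` acts as a single typing hub: training code pulls every annotation it needs with one star import, which is how `train_utils.py` above gets names like `Optimizer`, `LRScheduler`, `DictConfig`, and `Namespace` without importing them individually. A hypothetical downstream module illustrating the pattern:

import torch
from src.utils.typing_utils import *

def count_trainable(model: Module) -> int:
    # `Module` comes from the shared typing hub rather than a local torch.nn import.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)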