Commit 1774ce2: Add files

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +40 -0
- .gitignore +1 -0
- README.md +15 -0
- app.py +68 -0
- assets/teaser.jpg +0 -0
- custum_3d_diffusion/custum_modules/attention_processors.py +385 -0
- custum_3d_diffusion/custum_modules/unifield_processor.py +459 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py +298 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py +296 -0
- custum_3d_diffusion/modules.py +14 -0
- custum_3d_diffusion/trainings/__init__.py +0 -0
- custum_3d_diffusion/trainings/base.py +208 -0
- custum_3d_diffusion/trainings/config_classes.py +35 -0
- custum_3d_diffusion/trainings/image2image_trainer.py +86 -0
- custum_3d_diffusion/trainings/image2mvimage_trainer.py +139 -0
- custum_3d_diffusion/trainings/utils.py +25 -0
- gradio_app/__init__.py +0 -0
- gradio_app/all_models.py +22 -0
- gradio_app/custom_models/image2mvimage.yaml +63 -0
- gradio_app/custom_models/image2normal.yaml +61 -0
- gradio_app/custom_models/mvimg_prediction.py +59 -0
- gradio_app/custom_models/normal_prediction.py +28 -0
- gradio_app/custom_models/utils.py +75 -0
- gradio_app/examples/Groot.png +3 -0
- gradio_app/examples/aaa.png +3 -0
- gradio_app/examples/abma.png +3 -0
- gradio_app/examples/akun.png +3 -0
- gradio_app/examples/anya.png +3 -0
- gradio_app/examples/bag.png +3 -0
- gradio_app/examples/ex1.png +3 -0
- gradio_app/examples/ex2.png +3 -0
- gradio_app/examples/ex3.jpg +0 -0
- gradio_app/examples/ex4.png +3 -0
- gradio_app/examples/generated_1715761545_frame0.png +3 -0
- gradio_app/examples/generated_1715762357_frame0.png +3 -0
- gradio_app/examples/generated_1715763329_frame0.png +3 -0
- gradio_app/examples/hatsune_miku.png +3 -0
- gradio_app/examples/princess-large.png +3 -0
- gradio_app/gradio_3dgen.py +85 -0
- gradio_app/gradio_3dgen_steps.py +87 -0
- gradio_app/gradio_local.py +76 -0
- gradio_app/utils.py +112 -0
- mesh_reconstruction/func.py +133 -0
- mesh_reconstruction/opt.py +190 -0
- mesh_reconstruction/recon.py +59 -0
- mesh_reconstruction/refine.py +80 -0
- mesh_reconstruction/remesh.py +361 -0
- mesh_reconstruction/render.py +159 -0
- package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl +3 -0
- package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl +3 -0
    	
.gitattributes ADDED

@@ -0,0 +1,40 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.o filter=lfs diff=lfs merge=lfs -text
+*.ninja_deps filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
+*.whl filter=lfs diff=lfs merge=lfs -text
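The patterns above route archives, checkpoints, images, compiled objects, and the bundled wheels through Git LFS. As a quick sanity check, a minimal sketch (my illustration, assuming the file sits at the repository root and uses the plain "pattern attr attr ..." layout shown above) can list which patterns are LFS-tracked:

# Minimal sketch: list the patterns that .gitattributes routes through Git LFS.
# Assumes the simple "<pattern> <attr> <attr> ..." layout used in the file above.
from pathlib import Path

def lfs_patterns(path: str = ".gitattributes") -> list[str]:
    patterns = []
    for line in Path(path).read_text().splitlines():
        parts = line.split()
        if parts and "filter=lfs" in parts[1:]:
            patterns.append(parts[0])
    return patterns

print(lfs_patterns())  # e.g. ['*.7z', '*.arrow', ..., '*.whl']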
    	
.gitignore ADDED

@@ -0,0 +1 @@
+*.pyc
    	
README.md ADDED

@@ -0,0 +1,15 @@
+---
+title: Unique3D
+emoji: ⚡
+colorFrom: red
+colorTo: purple
+sdk: gradio
+python_version: 3.10.8
+sdk_version: 4.12.0
+app_file: app.py
+pinned: true
+short_description: Create a 1M faces 3D colored model from an image!
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py ADDED

@@ -0,0 +1,68 @@
+import shlex
+import subprocess
+subprocess.run(
+    shlex.split(
+        "pip install package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl --force-reinstall --no-deps"
+    )
+)
+subprocess.run(
+    shlex.split(
+        "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
+    )
+)
+
+if __name__ == "__main__":
+    import os
+    from huggingface_hub import login
+    hf_token = os.environ.get("HF_TOKEN")
+    login(token=hf_token)
+
+    import os
+    import sys
+    sys.path.append(os.curdir)
+    import torch
+    torch.set_float32_matmul_precision('medium')
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.set_grad_enabled(False)
+
+import fire
+import gradio as gr
+from gradio_app.gradio_3dgen import create_ui as create_3d_ui
+from gradio_app.all_models import model_zoo
+
+
+_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
+_DESCRIPTION = '''
+
+<div>
+    <a style="display:inline-block" href='https://github.com/AiuniAI/Unique3D'><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/AiuniAI/Unique3D?style=social">
+</a>
+    <img alt="GitHub License" src="https://img.shields.io/github/license/AiuniAI/Unique3D">
+</div>
+
+# [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [Online Demo](https://www.aiuni.ai/)
+
+* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
+
+* The demo is still under construction, and more features are expected to be implemented soon.
+
+* If the Huggingface Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.
+'''
+
+def launch():
+    model_zoo.init_models()
+
+    with gr.Blocks(
+        title=_TITLE,
+        # theme=gr.themes.Monochrome(),
+    ) as demo:
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown('# ' + _TITLE)
+        gr.Markdown(_DESCRIPTION)
+        create_3d_ui("wkl")
+
+    demo.queue().launch(share=True)
+
+if __name__ == '__main__':
+    fire.Fire(launch)
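app.py bootstraps the Space: it force-reinstalls the two wheels committed under package/ before any import that needs them, logs into the Hub with the HF_TOKEN secret, relaxes float32 matmul precision, disables gradients globally, and then launches the Gradio UI assembled by gradio_app.gradio_3dgen.create_ui. A hedged variation (my sketch, not part of this commit) that skips the forced reinstall when the packages are already importable, which shortens warm restarts, could look like this:

# Sketch only: install the bundled wheels just once per environment.
# Wheel paths are the ones committed under package/ above.
import importlib.util
import shlex
import subprocess

WHEELS = {
    "onnxruntime": "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
    "nvdiffrast": "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
}

for module_name, wheel_path in WHEELS.items():
    # Only reinstall when the module cannot be imported from the current env.
    if importlib.util.find_spec(module_name) is None:
        subprocess.run(shlex.split(f"pip install {wheel_path} --no-deps"), check=True)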
    	
assets/teaser.jpg ADDED

(image file; no text diff shown)
    	
custum_3d_diffusion/custum_modules/attention_processors.py ADDED

@@ -0,0 +1,385 @@
+from typing import Any, Dict, Optional
+import torch
+from diffusers.models.attention_processor import Attention
+
+def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
+    if norm_type == "layernorm":
+        norm = torch.nn.LayerNorm(hidden_states_dim)
+    else:
+        norm = torch.nn.Identity()
+    attention = Attention(
+        query_dim=hidden_states_dim,
+        heads=8,
+        dim_head=hidden_states_dim // 8,
+        bias=True,
+    )
+    # NOTE: xformers 0.22 does not support batchsize >= 4096
+    attention.xformers_not_supported = True # hacky solution
+    return norm, attention
+
+class ExtraAttnProc(torch.nn.Module):
+    def __init__(
+        self,
+        chained_proc,
+        enabled=False,
+        name=None,
+        mode='extract',
+        with_proj_in=False,
+        proj_in_dim=768,
+        target_dim=None,
+        pixel_wise_crosspond=False,
+        norm_type="none",   # none or layernorm
+        crosspond_effect_on="all",  # all or first
+        crosspond_chain_pos="parralle",     # before or parralle or after
+        simple_3d=False,
+        views=4,
+    ) -> None:
+        super().__init__()
+        self.enabled = enabled
+        self.chained_proc = chained_proc
+        self.name = name
+        self.mode = mode
+        self.with_proj_in=with_proj_in
+        self.proj_in_dim = proj_in_dim
+        self.target_dim = target_dim or proj_in_dim
+        self.hidden_states_dim = self.target_dim
+        self.pixel_wise_crosspond = pixel_wise_crosspond
+        self.crosspond_effect_on = crosspond_effect_on
+        self.crosspond_chain_pos = crosspond_chain_pos
+        self.views = views
+        self.simple_3d = simple_3d
+        if self.with_proj_in and self.enabled:
+            self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
+            if self.target_dim == self.proj_in_dim:
+                self.in_linear.weight.data = torch.eye(proj_in_dim)
+        else:
+            self.in_linear = None
+        if self.pixel_wise_crosspond and self.enabled:
+            self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
+
+    def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
+        hidden_states = self.crosspond_norm(hidden_states)
+
+        batch, L, D = hidden_states.shape
+        assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
+        # to -> batch * L, 1, D
+        hidden_states = hidden_states.reshape(batch * L, 1, D)
+        other_states = other_states.reshape(batch * L, 1, D)
+        hidden_states_catted = other_states
+        hidden_states = self.crosspond_attention(
+            hidden_states,
+            encoder_hidden_states=hidden_states_catted,
+        )
+        return hidden_states.reshape(batch, L, D)
+
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+        ref_dict: dict = None, mode=None, **kwargs
+    ) -> Any:
+        if not self.enabled:
+            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        assert ref_dict is not None
+        if (mode or self.mode) == 'extract':
+            ref_dict[self.name] = hidden_states
+            hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+            if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
+                ref_dict[self.name] = hidden_states1
+            return hidden_states1
+        elif (mode or self.mode) == 'inject':
+            ref_state = ref_dict.pop(self.name)
+            if self.with_proj_in:
+                ref_state = self.in_linear(ref_state)
+
+            B, L, D = ref_state.shape
+            if hidden_states.shape[0] == B:
+                modalities = 1
+                views = 1
+            else:
+                modalities = hidden_states.shape[0] // B // self.views
+                views = self.views
+            if self.pixel_wise_crosspond:
+                if self.crosspond_effect_on == "all":
+                    ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
+
+                    if self.crosspond_chain_pos == "before":
+                        hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
+
+                    hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+                    if self.crosspond_chain_pos == "parralle":
+                        hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
+
+                    if self.crosspond_chain_pos == "after":
+                        hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
+                    return hidden_states1
+                else:
+                    assert self.crosspond_effect_on == "first"
+                    # hidden_states [B * modalities * views, L, D]
+                    # ref_state [B, L, D]
+                    ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])  # [B * modalities, L, D]
+
+                    def do_paritial_crosspond(hidden_states, ref_state):
+                        first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0]  # [B * modalities, L, D]
+                        hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
+                        hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
+                        hidden_states2_padded[:, 0] = hidden_states2
+                        hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
+                        return hidden_states2_padded
+
+                    if self.crosspond_chain_pos == "before":
+                        hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
+
+                    hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)    # [B * modalities * views, L, D]
+                    if self.crosspond_chain_pos == "parralle":
+                        hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
+                    if self.crosspond_chain_pos == "after":
+                        hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
+                    return hidden_states1
+            elif self.simple_3d:
+                B, L, C = encoder_hidden_states.shape
+                mv = self.views
+                encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
+                ref_state = ref_state[:, None]
+                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+                encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
+                encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
+                return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+            else:
+                ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
+                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+                return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+        else:
+            raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
+
+def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
+    return_dict = torch.nn.ModuleDict()
+    proj_in_dim = kwargs.get('proj_in_dim', False)
+    kwargs.pop('proj_in_dim', None)
+
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            if "ref_unet" not in (sub_name + name):
+                recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            new_processor = ExtraAttnProc(
+                chained_proc=module.get_processor(),
+                enabled=enable_filter(f"{name}.processor"),
+                name=f"{name}.processor",
+                proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
+                target_dim=module.cross_attention_dim,
+                **kwargs
+            )
+            module.set_processor(new_processor)
+            return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+    return return_dict
+
+def switch_extra_processor(model, enable_filter=lambda x:True):
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, ExtraAttnProc):
+            module.enabled = enable_filter(name)
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+class multiviewAttnProc(torch.nn.Module):
+    def __init__(
+        self,
+        chained_proc,
+        enabled=False,
+        name=None,
+        hidden_states_dim=None,
+        chain_pos="parralle",     # before or parralle or after
+        num_modalities=1,
+        views=4,
+        base_img_size=64,
+    ) -> None:
+        super().__init__()
+        self.enabled = enabled
+        self.chained_proc = chained_proc
+        self.name = name
+        self.hidden_states_dim = hidden_states_dim
+        self.num_modalities = num_modalities
+        self.views = views
+        self.base_img_size = base_img_size
+        self.chain_pos = chain_pos
+        self.diff_joint_attn = True
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs
+    ) -> torch.Tensor:
+        if not self.enabled:
+            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+        B, L, C = hidden_states.shape
+        mv = self.views
+        hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
+        hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+        return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
+
+def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
+    return_dict = torch.nn.ModuleDict()
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            if "ref_unet" not in (sub_name + name):
+                recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            new_processor = multiviewAttnProc(
+                chained_proc=module.get_processor(),
+                enabled=enable_filter(f"{name}.processor"),
+                name=f"{name}.processor",
+                hidden_states_dim=module.inner_dim,
+                **kwargs
+            )
+            module.set_processor(new_processor)
+            return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+    return return_dict
+
+def switch_multiview_processor(model, enable_filter=lambda x:True):
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            processor = module.get_processor()
+            if isinstance(processor, multiviewAttnProc):
+                processor.enabled = enable_filter(f"{name}.processor")
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+class NNModuleWrapper(torch.nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
+
+class AttnProcessorSwitch(torch.nn.Module):
+    def __init__(
+        self,
+        proc_dict: dict,
+        enabled_proc="default",
+        name=None,
+        switch_name="default_switch",
+    ):
+        super().__init__()
+        self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
+        self.enabled_proc = enabled_proc
+        self.name = name
+        self.switch_name = switch_name
+        self.choose_module(enabled_proc)
+
+    def choose_module(self, enabled_proc):
+        self.enabled_proc = enabled_proc
+        assert enabled_proc in self.proc_dict.keys()
+
+    def __call__(
+        self,
+        *args,
+        **kwargs
+    ) -> torch.FloatTensor:
+        used_proc = self.proc_dict[self.enabled_proc]
+        return used_proc(*args, **kwargs)
+
+def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
+    return_dict = torch.nn.ModuleDict()
+    def recursive_add_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            if "ref_unet" not in (sub_name + name):
+                recursive_add_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            processor = module.get_processor()
+            if module_filter(processor):
+                proc_dict = switch_dict_fn(processor)
+                new_processor = AttnProcessorSwitch(
+                    proc_dict=proc_dict,
+                    enabled_proc=enabled_proc,
+                    name=f"{name}.processor",
+                    switch_name=switch_name,
+                )
+                module.set_processor(new_processor)
+                return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+    for name, module in model.named_children():
+        recursive_add_processors(name, module)
+
+    return return_dict
+
+def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
+    def recursive_change_processors(name: str, module: torch.nn.Module):
+        for sub_name, child in module.named_children():
+            recursive_change_processors(f"{name}.{sub_name}", child)
+
+        if isinstance(module, Attention):
+            processor = module.get_processor()
+            if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
+                processor.choose_module(enabled_proc)
+
+    for name, module in model.named_children():
+        recursive_change_processors(name, module)
+
+########## Hack: Attention fix #############
+from diffusers.models.attention import Attention
+
+def forward(
+    self,
+    hidden_states: torch.FloatTensor,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    **cross_attention_kwargs,
+) -> torch.Tensor:
+    r"""
+    The forward method of the `Attention` class.
+
+    Args:
+        hidden_states (`torch.Tensor`):
+            The hidden states of the query.
+        encoder_hidden_states (`torch.Tensor`, *optional*):
+            The hidden states of the encoder.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention mask to use. If `None`, no mask is applied.
+        **cross_attention_kwargs:
+            Additional keyword arguments to pass along to the cross attention.
+
+    Returns:
+        `torch.Tensor`: The output of the attention layer.
+    """
+    # The `Attention` class can call different attention processors / attention functions
+    # here we simply pass along all tensors to the selected processor class
+    # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+    return self.processor(
+        self,
+        hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        attention_mask=attention_mask,
+        **cross_attention_kwargs,
+    )
+
+Attention.forward = forward
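attention_processors.py implements the attention plumbing used by the rest of the commit: ExtraAttnProc wraps an existing diffusers attention processor and either records its hidden states into a shared ref_dict ('extract', meant for a reference U-Net pass) or conditions on those recorded states ('inject', meant for the main denoising pass); multiviewAttnProc folds the view axis into the token axis so self-attention runs jointly over all views; AttnProcessorSwitch with add_switch/change_switch lets one attention module toggle between several processors; and the Attention.forward patch at the end makes sure extra keyword arguments such as ref_dict actually reach the processors. A rough usage sketch of the extract/inject flow, using only the helpers defined above (unet, ref_unet, latents, ref_latents, cond, and t are hypothetical placeholders; the real wiring lives in unifield_processor.py and the custom pipelines):

# Rough sketch, not the repository's actual wiring: attach reference attention
# to two UNet2DConditionModel instances and share states through ref_dict.
ref_procs = add_extra_processor(
    ref_unet,                                            # hypothetical reference U-Net
    enable_filter=lambda name: name.endswith("attn1.processor"),
    mode="extract",
)
main_procs = add_extra_processor(
    unet,                                                # hypothetical main U-Net
    enable_filter=lambda name: name.endswith("attn1.processor"),
    mode="inject",
)

ref_dict = {}
# The reference pass fills ref_dict; the main pass pops and attends to those states.
ref_unet(ref_latents, t, encoder_hidden_states=cond,
         cross_attention_kwargs=dict(ref_dict=ref_dict))
noise_pred = unet(latents, t, encoder_hidden_states=cond,
                  cross_attention_kwargs=dict(ref_dict=ref_dict)).sample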
    	
        custum_3d_diffusion/custum_modules/unifield_processor.py
    ADDED
    
    | @@ -0,0 +1,459 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
from types import FunctionType
from typing import Any, Dict, List
from diffusers import UNet2DConditionModel
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
from dataclasses import dataclass, field
from diffusers.loaders import IPAdapterMixin
from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch

@dataclass
class AttnConfig:
    """
    * CrossAttention: Attention module (inherits knowledge), LoRA module (fine-tuning), IPAdapter module (conceptual control).
    * SelfAttention: Attention module (inherits knowledge), LoRA module (fine-tuning), Reference Attention module (pixel-level control).
    * Multiview Attention: Multiview Attention module (multi-view consistency).
    * Cross Modality Attention: Cross Modality Attention module (multi-modality consistency).

    Setup notes:
        train_xxx_lr is implemented in the U-Net architecture.
        enable_xxx_lora is implemented in the U-Net architecture.
        enable_xxx_ip is implemented in the processor and the U-Net architecture.
        enable_xxx_ref_proj_in is implemented in the processor.
    """
    latent_size: int = 64

    train_lr: float = 0
    # for cross attention
    # a learning rate of 0 means the corresponding weights are not trained
    train_cross_attn_lr: float = 0
    train_cross_attn_lora_lr: float = 0
    train_cross_attn_ip_lr: float = 0      # 0 for not trained
    init_cross_attn_lora: bool = False
    enable_cross_attn_lora: bool = False
    init_cross_attn_ip: bool = False
    enable_cross_attn_ip: bool = False
    cross_attn_lora_rank: int = 64        # 0 for not enabled
    cross_attn_lora_only_kv: bool = False
    ipadapter_pretrained_name: str = "h94/IP-Adapter"
    ipadapter_subfolder_name: str = "models"
    ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
    ipadapter_effect_on: str = "all"    # all, first

    # for self attention
    train_self_attn_lr: float = 0
    train_self_attn_lora_lr: float = 0
    init_self_attn_lora: bool = False
    enable_self_attn_lora: bool = False
    self_attn_lora_rank: int = 64
    self_attn_lora_only_kv: bool = False

    train_self_attn_ref_lr: float = 0
    train_ref_unet_lr: float = 0
    init_self_attn_ref: bool = False
    enable_self_attn_ref: bool = False
    self_attn_ref_other_model_name: str = ""
    self_attn_ref_position: str = "attn1"
    self_attn_ref_pixel_wise_crosspond: bool = False    # enable pixel_wise_crosspond in the reference attention processor
    self_attn_ref_chain_pos: str = "parralle"           # "before", "parralle" (parallel) or "after"
    self_attn_ref_effect_on: str = "all"                # "all" or "first", for the crosspond attention
    self_attn_ref_zero_init: bool = True
    use_simple3d_attn: bool = False

    # for multiview attention
    init_multiview_attn: bool = False
    enable_multiview_attn: bool = False
    multiview_attn_position: str = "attn1"
    multiview_chain_pose: str = "parralle"             # "before", "parralle" (parallel) or "after"
    num_modalities: int = 1
    use_mv_joint_attn: bool = False

    # for unet
    init_unet_path: str = "runwayml/stable-diffusion-v1-5"
    init_num_cls_label: int = 0                         # number of class labels used at initialization
    cls_labels: List[int] = field(default_factory=lambda: [])
    cls_label_type: str = "embedding"
    cat_condition: bool = False                         # concatenate the condition latents to the input

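# A minimal usage sketch of AttnConfig (the field values here are illustrative assumptions,
# not settings shipped with this repo): enable multiview attention for a single modality
# while leaving LoRA, IP-Adapter and reference attention disabled.
# example_attn_config = AttnConfig(
#     latent_size=32,                 # e.g. 256x256 images with an 8x VAE downsample
#     init_multiview_attn=True,
#     enable_multiview_attn=True,
#     num_modalities=1,
#     init_num_cls_label=8,           # assumed: one label per camera pose
#     cls_labels=[0, 1, 2, 3],
# )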
class Configurable:
    attn_config: AttnConfig

    def set_config(self, attn_config: AttnConfig):
        raise NotImplementedError()

    def update_config(self, attn_config: AttnConfig):
        self.attn_config = attn_config

    def do_set_config(self, attn_config: AttnConfig):
        self.set_config(attn_config)
        for name, module in self.named_modules():
            if isinstance(module, Configurable):
                if hasattr(module, "do_set_config"):
                    module.do_set_config(attn_config)
                else:
                    print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
                    module.attn_config = attn_config

    def do_update_config(self, attn_config: AttnConfig):
        self.update_config(attn_config)
        for name, module in self.named_modules():
            if isinstance(module, Configurable):
                if hasattr(module, "do_update_config"):
                    module.do_update_config(attn_config)
                else:
                    print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
                    module.attn_config = attn_config

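# Illustrative sketch (the class below is an assumption for the example, not part of this
# repo): mixing Configurable into a torch.nn.Module lets do_set_config()/do_update_config()
# reach the block through named_modules() and propagate a new AttnConfig to it.
# class MyConfigurableBlock(torch.nn.Module, Configurable):
#     def set_config(self, attn_config: AttnConfig):
#         self.attn_config = attn_config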
from diffusers import ModelMixin  # Must import ModelMixin for CompiledUNet
class UnifieldWrappedUNet(UNet2DConditionModel):
    forward_hook: FunctionType

    def forward(self, *args, **kwargs):
        if hasattr(self, 'forward_hook'):
            return self.forward_hook(super().forward, *args, **kwargs)
        return super().forward(*args, **kwargs)

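# Illustrative sketch: forward_hook can be any callable taking (raw_forward, *args, **kwargs);
# ConfigurableUNet2DConditionModel below installs its unet_forward_hook here. A pass-through
# hook (an assumption for the example) would look like:
# unet.forward_hook = lambda raw_forward, *args, **kwargs: raw_forward(*args, **kwargs)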
class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
    unet: UNet2DConditionModel

    cls_embedding_param_dict = {}
    cross_attn_lora_param_dict = {}
    self_attn_lora_param_dict = {}
    cross_attn_param_dict = {}
    self_attn_param_dict = {}
    ipadapter_param_dict = {}
    ref_attn_param_dict = {}
    ref_unet_param_dict = {}
    multiview_attn_param_dict = {}
    other_param_dict = {}

    rev_param_name_mapping = {}

    class_labels = []
    def set_class_labels(self, class_labels: torch.Tensor):
        if self.attn_config.init_num_cls_label != 0:
            self.class_labels = class_labels.to(self.unet.device).long()

    def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
        super().__init__()
        self.weight_dtype = weight_dtype
        self.set_config(init_config)

    def enable_xformers_memory_efficient_attention(self):
        self.unet.enable_xformers_memory_efficient_attention
        def recursive_add_processors(name: str, module: torch.nn.Module):
            for sub_name, child in module.named_children():
                recursive_add_processors(f"{name}.{sub_name}", child)

            if isinstance(module, Attention):
                if hasattr(module, 'xformers_not_supported'):
                    return
                old_processor = module.get_processor()
                if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
                    module.set_use_memory_efficient_attention_xformers(True)

        for name, module in self.unet.named_children():
            recursive_add_processors(name, module)

    def __getattr__(self, name: str) -> Any:
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    # --- for IPAdapterMixin

    def register_modules(self, **kwargs):
        for name, module in kwargs.items():
            # set models
            setattr(self, name, module)

    def register_to_config(self, **kwargs):
        pass

    def unload_ip_adapter(self):
        raise NotImplementedError()

    # --- for Configurable

    def get_refunet(self):
        if self.attn_config.self_attn_ref_other_model_name == "self":
            return self.unet
        else:
            return self.unet.ref_unet

    def set_config(self, attn_config: AttnConfig):
        self.attn_config = attn_config

        unet_type = UnifieldWrappedUNet
        # class_embed_type = "projection" for 'camera'
        # class_embed_type = None for 'embedding'
        unet_kwargs = {}
        if attn_config.init_num_cls_label > 0:
            if attn_config.cls_label_type == "embedding":
                unet_kwargs = {
                    "num_class_embeds": attn_config.init_num_cls_label,
                    "device_map": None,
                    "low_cpu_mem_usage": False,
                    "class_embed_type": None,
                }
            else:
                raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")

        self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
            attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
            **unet_kwargs
        )
        assert isinstance(self.unet, UnifieldWrappedUNet)
        self.unet.forward_hook = self.unet_forward_hook

        if self.attn_config.cat_condition:
            # double in_channels
            if self.unet.config.in_channels != 8:
                self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
                # double conv_in: the weights for the new input channels are zero-initialized,
                # so the concatenated condition initially has no effect
                doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
                doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
                doubled_conv_in.bias.data = self.unet.conv_in.bias.data
                self.unet.conv_in = doubled_conv_in

        used_param_ids = set()

        if attn_config.init_cross_attn_lora:
            # setup lora
            from peft import LoraConfig
            from peft.utils import get_peft_model_state_dict
            if attn_config.cross_attn_lora_only_kv:
                target_modules=["attn2.to_k", "attn2.to_v"]
            else:
                target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
            lora_config: LoraConfig = LoraConfig(
                r=attn_config.cross_attn_lora_rank,
                lora_alpha=attn_config.cross_attn_lora_rank,
                init_lora_weights="gaussian",
                target_modules=target_modules,
            )
            adapter_name="cross_attn_lora"
            self.unet.add_adapter(lora_config, adapter_name=adapter_name)
            # update cross_attn_lora_param_dict
            self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
            used_param_ids.update(self.cross_attn_lora_param_dict.keys())

        if attn_config.init_self_attn_lora:
            # setup lora
            from peft import LoraConfig
            if attn_config.self_attn_lora_only_kv:
                target_modules=["attn1.to_k", "attn1.to_v"]
            else:
                target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
            lora_config: LoraConfig = LoraConfig(
                r=attn_config.self_attn_lora_rank,
                lora_alpha=attn_config.self_attn_lora_rank,
                init_lora_weights="gaussian",
                target_modules=target_modules,
            )
            adapter_name="self_attn_lora"
            self.unet.add_adapter(lora_config, adapter_name=adapter_name)
            # update self_attn_lora_param_dict
            self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
            used_param_ids.update(self.self_attn_lora_param_dict.keys())

        if attn_config.init_num_cls_label != 0:
            self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
            used_param_ids.update(self.cls_embedding_param_dict.keys())
            self.set_class_labels(torch.tensor(attn_config.cls_labels).long())

        if attn_config.init_cross_attn_ip:
            self.image_encoder = None
            # setup ipadapter
            self.load_ip_adapter(
                attn_config.ipadapter_pretrained_name,
                subfolder=attn_config.ipadapter_subfolder_name,
                weight_name=attn_config.ipadapter_weight_name
            )
            # wrap ip_adapter_attn_proc with a switch
            from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
            add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
            # update ipadapter_param_dict
            # the weights live in the attention processors and in unet.encoder_hid_proj
            self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.ipadapter_param_dict.keys())
            print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
            for name, processor in self.unet.attn_processors.items():
                if hasattr(processor, "to_k_ip"):
                    self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
            print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))

        ref_unet = None
        if attn_config.init_self_attn_ref:
            # setup reference attention processor
            if attn_config.self_attn_ref_other_model_name == "self":
                raise NotImplementedError("self reference is not fully implemented")
            else:
                ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
                    attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
                )
                ref_unet.to(self.unet.device)
                if self.attn_config.train_ref_unet_lr == 0:
                    ref_unet.eval()
                    ref_unet.requires_grad_(False)
                else:
                    ref_unet.train()

                add_extra_processor(
                    model=ref_unet,
                    enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
                    mode='extract',
                    with_proj_in=False,
                    pixel_wise_crosspond=False,
                )
                # NOTE: the cross_attention_dim of the self attention must be the same in both UNets
                processor_dict = add_extra_processor(
                    model=self.unet,
                    enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
                    mode='inject',
                    with_proj_in=False,
                    pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
                    crosspond_effect_on=attn_config.self_attn_ref_effect_on,
                    crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
                    simple_3d=attn_config.use_simple3d_attn,
                )
                self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
                if attn_config.self_attn_ref_chain_pos != "after":
                    # pop untrainable parameters
                    for name, param in ref_unet.named_parameters():
                        if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
                            self.ref_unet_param_dict.pop(id(param))
                used_param_ids.update(self.ref_unet_param_dict.keys())
            # update ref_attn_param_dict
            self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.ref_attn_param_dict.keys())

        if attn_config.init_multiview_attn:
            processor_dict = add_multiview_processor(
                model=self.unet,
                enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
                num_modalities=attn_config.num_modalities,
                base_img_size=attn_config.latent_size,
                chain_pos=attn_config.multiview_chain_pose,
            )
            # update multiview_attn_param_dict
            self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.multiview_attn_param_dict.keys())

        # initialize cross_attn_param_dict parameters
        self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
        used_param_ids.update(self.cross_attn_param_dict.keys())

        # initialize self_attn_param_dict parameters
        self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
        used_param_ids.update(self.self_attn_param_dict.keys())

        # initialize other_param_dict parameters
        self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}

        if ref_unet is not None:
            self.unet.ref_unet = ref_unet

        self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}

        self.update_config(attn_config, force_update=True)
        return self.unet

    _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]

    def update_config(self, attn_config: AttnConfig, force_update=False):
        assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"

        need_to_update = False
        # check whether any switchable key changed
        for key in self._attn_keys_to_update:
            if getattr(self.attn_config, key) != getattr(attn_config, key):
                need_to_update = True
                break
        if not force_update and not need_to_update:
            return

        self.set_class_labels(torch.tensor(attn_config.cls_labels).long())

        # setup loras
        if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
            if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
                cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0
                self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0
                self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
            else:
                self.unet.disable_adapters()

        # setup ipadapter
        if self.attn_config.init_cross_attn_ip:
            if attn_config.enable_cross_attn_ip:
                change_switch(self.unet, "ipadapter_switch", "ipadapter")
            else:
                change_switch(self.unet, "ipadapter_switch", "default")

        # setup reference attention processor
        if self.attn_config.init_self_attn_ref:
            if attn_config.enable_self_attn_ref:
                switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
            else:
                switch_extra_processor(self.unet, enable_filter=lambda name: False)

        # setup multiview attention processor
        if self.attn_config.init_multiview_attn:
            if attn_config.enable_multiview_attn:
                switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
            else:
                switch_multiview_processor(self.unet, enable_filter=lambda name: False)

        # copy the switchable keys (including cls_labels) into the stored config
        for key in self._attn_keys_to_update:
            setattr(self.attn_config, key, getattr(attn_config, key))

    def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
        if class_labels is None and len(self.class_labels) > 0:
            class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
        elif self.attn_config.init_num_cls_label != 0:
            assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
        if class_labels is not None:
            if self.attn_config.cls_label_type == "embedding":
                pass
            else:
                raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
        if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
            # NOTE: extra step, extract the condition
            ref_dict = {}
            ref_unet = self.get_refunet().to(sample.device)
            assert condition_latents is not None
            if self.attn_config.self_attn_ref_other_model_name == "self":
                raise NotImplementedError()
            else:
                with torch.no_grad():
                    cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
                    if timestep.dim() == 0:
                        cond_timestep = timestep
                    else:
                        cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
                ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
            # NOTE: extra step, inject the condition
            # Predict the noise residual and compute loss
            if cross_attention_kwargs is None:
                cross_attention_kwargs = {}
            cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
        elif condition_latents is not None:
            if not hasattr(self, 'condition_latents_raised'):
                print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
                self.condition_latents_raised = True

        if self.attn_config.init_cross_attn_ip:
            raise NotImplementedError()

        if self.attn_config.cat_condition:
            assert condition_latents is not None
            B = condition_latents.shape[0]
            cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
            sample = torch.cat([sample, cat_latents], dim=1)

        return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
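# Illustrative usage sketch (the surrounding training/inference loop is an assumption, not
# part of this file): build a config, construct the wrapper (which loads and patches the
# UNet), and later flip the switchable enable_* flags without rebuilding the model.
# import dataclasses
# config = AttnConfig(init_multiview_attn=True, enable_multiview_attn=True, num_modalities=1)
# wrapped = ConfigurableUNet2DConditionModel(config, weight_dtype=torch.float16)
# wrapped.enable_xformers_memory_efficient_attention()
# wrapped.update_config(dataclasses.replace(config, enable_multiview_attn=False))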
    	
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py ADDED
@@ -0,0 +1,298 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modified by Wuvin


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection



class StableDiffusionImageCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset=None,
        noisy_cond_latents=False,
    ):
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker
        )
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        self.noisy_cond_latents = noisy_cond_latents
        self.register_to_config(noisy_cond_latents=noisy_cond_latents)

    def encode_latents(self, image, device, dtype, height, width):
        # supports batch size > 1
        if isinstance(image, Image.Image):
            image = [image]
        image = [img.convert("RGB") for img in image]
        images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
        latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
        if self.latents_offset is not None:
            return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
        else:
            return latents

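    # Illustrative note (shape and offset values are assumptions for the example):
    # encode_latents() returns latents of shape (B, 4, height//8, width//8) for an SD VAE;
    # when latents_offset is a 4-tuple, the [None, :, None, None] indexing broadcasts one
    # offset per latent channel, e.g. latents_offset=(0.0, 0.1, -0.05, 0.2) shifts each of
    # the 4 channels independently.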
    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # NOTE: the same as original code
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

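    # Illustrative note: with classifier-free guidance enabled the returned batch is doubled,
    # e.g. one input image and num_images_per_prompt=1 yields embeddings of shape
    # (2, 1, embed_dim): the first half is the zero (unconditional) embedding, the second
    # half the CLIP image embedding.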
    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        upper_left_feature: bool = False,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 162 | 
            +
                            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
         | 
| 163 | 
            +
                            otherwise a `tuple` is returned where the first element is a list with the generated images and the
         | 
| 164 | 
            +
                            second element is a list of `bool`s indicating whether the corresponding generated image contains
         | 
| 165 | 
            +
                            "not-safe-for-work" (nsfw) content.
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    Examples:
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    ```py
         | 
| 170 | 
            +
                    from diffusers import StableDiffusionImageVariationPipeline
         | 
| 171 | 
            +
                    from PIL import Image
         | 
| 172 | 
            +
                    from io import BytesIO
         | 
| 173 | 
            +
                    import requests
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    pipe = StableDiffusionImageVariationPipeline.from_pretrained(
         | 
| 176 | 
            +
                        "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
         | 
| 177 | 
            +
                    )
         | 
| 178 | 
            +
                    pipe = pipe.to("cuda")
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    response = requests.get(url)
         | 
| 183 | 
            +
                    image = Image.open(BytesIO(response.content)).convert("RGB")
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
         | 
| 186 | 
            +
                    out["images"][0].save("result.jpg")
         | 
| 187 | 
            +
                    ```
         | 
| 188 | 
            +
                    """
         | 
| 189 | 
            +
                    # 0. Default height and width to unet
         | 
| 190 | 
            +
                    height = height or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 191 | 
            +
                    width = width or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # 1. Check inputs. Raise error if not correct
         | 
| 194 | 
            +
                    self.check_inputs(image, height, width, callback_steps)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    # 2. Define call parameters
         | 
| 197 | 
            +
                    if isinstance(image, Image.Image):
         | 
| 198 | 
            +
                        batch_size = 1
         | 
| 199 | 
            +
                    elif isinstance(image, list):
         | 
| 200 | 
            +
                        batch_size = len(image)
         | 
| 201 | 
            +
                    else:
         | 
| 202 | 
            +
                        batch_size = image.shape[0]
         | 
| 203 | 
            +
                    device = self._execution_device
         | 
| 204 | 
            +
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         | 
| 205 | 
            +
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         | 
| 206 | 
            +
                    # corresponds to doing no classifier free guidance.
         | 
| 207 | 
            +
                    do_classifier_free_guidance = guidance_scale > 1.0
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    # 3. Encode input image
         | 
| 210 | 
            +
                    if isinstance(image, Image.Image) and upper_left_feature:
         | 
| 211 | 
            +
                        # only use the first one of four images
         | 
| 212 | 
            +
                        emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
         | 
| 213 | 
            +
                    else:
         | 
| 214 | 
            +
                        emb_image = image
         | 
| 215 | 
            +
                    
         | 
| 216 | 
            +
                    image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
         | 
| 217 | 
            +
                    cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    # 4. Prepare timesteps
         | 
| 220 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         | 
| 221 | 
            +
                    timesteps = self.scheduler.timesteps
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    # 5. Prepare latent variables
         | 
| 224 | 
            +
                    num_channels_latents = self.unet.config.out_channels
         | 
| 225 | 
            +
                    latents = self.prepare_latents(
         | 
| 226 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 227 | 
            +
                        num_channels_latents,
         | 
| 228 | 
            +
                        height,
         | 
| 229 | 
            +
                        width,
         | 
| 230 | 
            +
                        image_embeddings.dtype,
         | 
| 231 | 
            +
                        device,
         | 
| 232 | 
            +
                        generator,
         | 
| 233 | 
            +
                        latents,
         | 
| 234 | 
            +
                    )
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # 6. Prepare extra step kwargs.
         | 
| 237 | 
            +
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    # 7. Denoising loop
         | 
| 240 | 
            +
                    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         | 
| 241 | 
            +
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 242 | 
            +
                        for i, t in enumerate(timesteps):
         | 
| 243 | 
            +
                            if self.noisy_cond_latents:
         | 
| 244 | 
            +
                                raise ValueError("Noisy condition latents is not recommended.")
         | 
| 245 | 
            +
                            else:
         | 
| 246 | 
            +
                                noisy_cond_latents = cond_latents
         | 
| 247 | 
            +
                            
         | 
| 248 | 
            +
                            noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
         | 
| 249 | 
            +
                            # expand the latents if we are doing classifier free guidance
         | 
| 250 | 
            +
                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         | 
| 251 | 
            +
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         | 
| 252 | 
            +
                                
         | 
| 253 | 
            +
                            # predict the noise residual
         | 
| 254 | 
            +
                            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                            # perform guidance
         | 
| 257 | 
            +
                            if do_classifier_free_guidance:
         | 
| 258 | 
            +
                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 259 | 
            +
                                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                            # compute the previous noisy sample x_t -> x_t-1
         | 
| 262 | 
            +
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                            # call the callback, if provided
         | 
| 265 | 
            +
                            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         | 
| 266 | 
            +
                                progress_bar.update()
         | 
| 267 | 
            +
                                if callback is not None and i % callback_steps == 0:
         | 
| 268 | 
            +
                                    step_idx = i // getattr(self.scheduler, "order", 1)
         | 
| 269 | 
            +
                                    callback(step_idx, t, latents)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    self.maybe_free_model_hooks()
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    if self.latents_offset is not None:
         | 
| 274 | 
            +
                        latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    if not output_type == "latent":
         | 
| 277 | 
            +
                        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         | 
| 278 | 
            +
                        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
         | 
| 279 | 
            +
                    else:
         | 
| 280 | 
            +
                        image = latents
         | 
| 281 | 
            +
                        has_nsfw_concept = None
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    if has_nsfw_concept is None:
         | 
| 284 | 
            +
                        do_denormalize = [True] * image.shape[0]
         | 
| 285 | 
            +
                    else:
         | 
| 286 | 
            +
                        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    self.maybe_free_model_hooks()
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    if not return_dict:
         | 
| 293 | 
            +
                        return (image, has_nsfw_concept)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
         | 
| 296 | 
            +
                
         | 
| 297 | 
            +
            if __name__ == "__main__":
         | 
| 298 | 
            +
                pass
         | 
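
Both custom pipelines in this directory follow the standard classifier-free-guidance convention visible above: the unconditional branch is represented by zeroed image embeddings and zeroed condition latents, the two branches are concatenated into one batch, the UNet runs once, and the two predictions are recombined with `guidance_scale`. A minimal, self-contained sketch of that recombination step with toy tensors (not tied to this repository's UNet):

```python
import torch

guidance_scale = 7.5
# Stand-in for one UNet call on torch.cat([uncond_batch, cond_batch]):
# batch element 0 is the unconditional prediction, element 1 the conditional one.
noise_pred = torch.randn(2, 4, 64, 64)

noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# With guidance_scale = 1.0 the combination reduces to the conditional prediction,
# which is why the pipelines only enable CFG when guidance_scale > 1.
assert torch.allclose(
    noise_pred_uncond + 1.0 * (noise_pred_cond - noise_pred_uncond), noise_pred_cond
)
```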
    	
        custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py
    ADDED
    
@@ -0,0 +1,296 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modified by Wuvin


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection



class StableDiffusionImage2MVCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset=None,
        noisy_cond_latents=False,
        condition_offset=True,
    ):
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker
        )
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        if noisy_cond_latents:
            raise NotImplementedError("Noisy condition latents not supported Now.")
        self.condition_offset = condition_offset
        self.register_to_config(condition_offset=condition_offset)

    def encode_latents(self, image: Image.Image, device, dtype, height, width):
        images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
        # NOTE: .mode() for condition
        latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
        if self.latents_offset is not None and self.condition_offset:
            return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
        else:
            return latents

    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # NOTE: the same as original code
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.

        Examples:

        ```py
        from diffusers import StableDiffusionImageVariationPipeline
        from PIL import Image
        from io import BytesIO
        import requests

        pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
        )
        pipe = pipe.to("cuda")

        url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"

        response = requests.get(url)
        image = Image.open(BytesIO(response.content)).convert("RGB")

        out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
        out["images"][0].save("result.jpg")
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, Image.Image):
            batch_size = 1
        elif len(image) == 1:
            image = image[0]
            batch_size = 1
        else:
            raise NotImplementedError()
        # elif isinstance(image, list):
        #     batch_size = len(image)
        # else:
        #     batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        emb_image = image

        image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
        cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
        cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
        image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
        if do_classifier_free_guidance:
            image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.out_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )


        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        self.maybe_free_model_hooks()

        if self.latents_offset is not None:
            latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

if __name__ == "__main__":
    pass
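
Because `StableDiffusionImage2MVCustomPipeline` is a subclass of `StableDiffusionImageVariationPipeline`, a checkpoint saved with this class can in principle be loaded through the usual `from_pretrained` route. A minimal usage sketch, assuming a hypothetical local checkpoint directory and illustrative parameter values (the real weights and wiring are set up elsewhere in the repository):

```python
import torch
from PIL import Image

from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import (
    StableDiffusionImage2MVCustomPipeline,
)

# Hypothetical checkpoint path, for illustration only.
pipe = StableDiffusionImage2MVCustomPipeline.from_pretrained(
    "path/to/image2mvimage-checkpoint", torch_dtype=torch.float16
).to("cuda")

front_view = Image.open("input.png").convert("RGB")

# A single conditioning image goes in; `height`/`width` set the generated output size,
# while `height_cond`/`width_cond` set the resolution at which the same image is
# VAE-encoded as `condition_latents` for the UNet.
out = pipe(
    front_view,
    height=1024, width=1024,
    height_cond=512, width_cond=512,
    num_inference_steps=50,
    guidance_scale=3.0,  # illustrative value
)
out.images[0].save("multiview_output.png")
```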
    	
        custum_3d_diffusion/modules.py
    ADDED
    
@@ -0,0 +1,14 @@
__modules__ = {}

def register(name):
    def decorator(cls):
        __modules__[name] = cls
        return cls

    return decorator


def find(name):
    return __modules__[name]

from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
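
This module keeps a small string-to-class registry: `register(name)` is a class decorator that stores the class under `name`, and `find(name)` looks it up again; the trainer modules imported at the bottom register themselves as a side effect of being imported. A minimal sketch with a hypothetical registry name:

```python
from custum_3d_diffusion.modules import register, find

@register("toy_trainer")  # hypothetical name, for illustration only
class ToyTrainer:
    pass

# find() returns the class object that was registered under the name.
assert find("toy_trainer") is ToyTrainer
```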
    	
        custum_3d_diffusion/trainings/__init__.py
    ADDED
    
File without changes
    	
        custum_3d_diffusion/trainings/base.py
    ADDED
    
@@ -0,0 +1,208 @@
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb

from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig

class BasicTrainer(torch.nn.Module, abc.ABC):
    accelerator: Accelerator
    logger: MultiProcessAdapter
    unet: ConfigurableUNet2DConditionModel
    train_dataloader: torch.utils.data.DataLoader
    test_dataset: torch.utils.data.Dataset
    attn_config: AttnConfig

    @dataclass
    class TrainerConfig:
        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""

        attn_config: dict = field(default_factory=dict)
        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        resolution: str = "1024"
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4

        max_train_steps: int = -1                       # -1 means infinity, otherwise [0, max_train_steps)
        training_step_interval: int = 1                 # train on step i*interval, stop at max_train_steps
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None                      # For dataset related operations and validation stuff
        train_batch_size: int = 1

        validation_interval: int = 5000
        debug: bool = False

    cfg: TrainerConfig    # only enable_xxx is used

    def __init__(
        self,
        accelerator: Accelerator,
        logger: MultiProcessAdapter,
        unet: ConfigurableUNet2DConditionModel,
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        super().__init__()
        self.index = index              # index in all trainers
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        self.ext_logs = {}
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None
        self.validate_trainer_config()
        self.configure()
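    # NOTE (illustrative comment, not in the original file): `config` may be a dict or a
    # string accepted by load_config, which maps it onto the TrainerConfig dataclass above;
    # e.g. a hypothetical {"trainer_name": "basic", "resolution": "[256, 256]",
    # "train_batch_size": 2, "attn_config": {...}} populates self.cfg, and
    # self.cfg.attn_config is then parsed a second time into an AttnConfig.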

    def get_HW(self):
        resolution = json.loads(self.cfg.resolution)
        if isinstance(resolution, int):
            H = W = resolution
        elif isinstance(resolution, list):
            H, W = resolution
        return H, W
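    # NOTE (illustrative comment, not in the original file): `resolution` is kept as a
    # JSON string, so both scalar and [H, W] forms are accepted:
    #   json.loads("1024")        -> 1024         -> H = W = 1024
    #   json.loads("[512, 1024]") -> [512, 1024]  -> H, W = 512, 1024
    # Any other JSON value would leave H and W unbound and raise at `return H, W`.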

    def unet_update(self):
        self.unet.update_config(self.attn_config)

    def validate_trainer_config(self):
        pass

    def is_train_finished(self, current_step):
        assert isinstance(self.cfg.max_train_steps, int)
        return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps

    def next_train_step(self, current_step):
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        pass

    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        pass

    def load_dataset(self):
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
        pass

    @abc.abstractmethod
    def forward_step(
        self,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        input a batch
        return a loss
        """
        self.unet_update()
        pass

    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        pass

    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """
            For inference time forward.
        """
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        pass

    def do_validation(
| 147 | 
            +
                    self,
         | 
| 148 | 
            +
                    shared_modules,
         | 
| 149 | 
            +
                    unet,
         | 
| 150 | 
            +
                    global_step,
         | 
| 151 | 
            +
                ):
         | 
| 152 | 
            +
                    self.unet_update()
         | 
| 153 | 
            +
                    self.logger.info("Running validation... ")
         | 
| 154 | 
            +
                    pipeline = self.construct_pipeline(shared_modules, unet)
         | 
| 155 | 
            +
                    pipeline.set_progress_bar_config(disable=True)
         | 
| 156 | 
            +
                    titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
         | 
| 157 | 
            +
                    for tracker in self.accelerator.trackers:
         | 
| 158 | 
            +
                        if tracker.name == "tensorboard":
         | 
| 159 | 
            +
                            np_images = np.stack([np.asarray(img) for img in images])
         | 
| 160 | 
            +
                            tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
         | 
| 161 | 
            +
                        elif tracker.name == "wandb":
         | 
| 162 | 
            +
                            [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title]   # inplace operation
         | 
| 163 | 
            +
                            tracker.log({"validation": [
         | 
| 164 | 
            +
                                wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
         | 
| 165 | 
            +
                                for i, image in enumerate(images)]})
         | 
| 166 | 
            +
                        else:
         | 
| 167 | 
            +
                            self.logger.warn(f"image logging not implemented for {tracker.name}")
         | 
| 168 | 
            +
                    del pipeline
         | 
| 169 | 
            +
                    torch.cuda.empty_cache()
         | 
| 170 | 
            +
                    return images
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                @torch.no_grad()
         | 
| 174 | 
            +
                def log_validation(
         | 
| 175 | 
            +
                    self,
         | 
| 176 | 
            +
                    shared_modules,
         | 
| 177 | 
            +
                    unet,
         | 
| 178 | 
            +
                    global_step,
         | 
| 179 | 
            +
                    force=False
         | 
| 180 | 
            +
                ):
         | 
| 181 | 
            +
                    if self.accelerator.is_main_process:
         | 
| 182 | 
            +
                        for tracker in self.accelerator.trackers:
         | 
| 183 | 
            +
                            if tracker.name == "wandb":
         | 
| 184 | 
            +
                                tracker.log(self.ext_logs)
         | 
| 185 | 
            +
                    self.ext_logs = {}
         | 
| 186 | 
            +
                    if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
         | 
| 187 | 
            +
                        self.unet_update()
         | 
| 188 | 
            +
                        if self.accelerator.is_main_process:
         | 
| 189 | 
            +
                            self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def save_model(self, unwrap_unet, shared_modules, save_dir):
         | 
| 192 | 
            +
                    if self.accelerator.is_main_process:
         | 
| 193 | 
            +
                        pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
         | 
| 194 | 
            +
                        pipeline.save_pretrained(save_dir)
         | 
| 195 | 
            +
                        self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def save_debug_info(self, save_name="debug", **kwargs):
         | 
| 198 | 
            +
                    if self.cfg.debug:
         | 
| 199 | 
            +
                        to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
         | 
| 200 | 
            +
                        import pickle
         | 
| 201 | 
            +
                        import os
         | 
| 202 | 
            +
                        if os.path.exists(f"{save_name}.pkl"):
         | 
| 203 | 
            +
                            for i in range(100):
         | 
| 204 | 
            +
                                if not os.path.exists(f"{save_name}_v{i}.pkl"):
         | 
| 205 | 
            +
                                    save_name = f"{save_name}_v{i}"
         | 
| 206 | 
            +
                                    break
         | 
| 207 | 
            +
                        with open(f"{save_name}.pkl", "wb") as f:
         | 
| 208 | 
            +
                            pickle.dump(to_saves, f)
         | 
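The validation images logged above are tiled with make_image_into_grid, which packs a flat list of views into 2x2 tiles and concatenates the tiles into one strip. A minimal sketch of that behaviour, assuming the repository and its dependencies are importable and using dummy solid-colour tiles in place of pipeline outputs:

from PIL import Image
from custum_3d_diffusion.trainings.base import BasicTrainer

# Eight dummy 256x256 views stand in for the images produced during validation.
views = [Image.new("RGB", (256, 256), color=(30 * i, 60, 120)) for i in range(8)]
# 8 images with rows=2, columns=2 -> two 2x2 tiles, concatenated into a single 1x2 strip.
grid = BasicTrainer.make_image_into_grid(views, rows=2, columns=2)
grid.save("validation_grid.png")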
    	
custum_3d_diffusion/trainings/config_classes.py
ADDED
@@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    trainer_type: str = ""
    trainer: dict = field(default_factory=dict)


@dataclass
class ExprimentConfig:
    trainers: List[dict] = field(default_factory=lambda: [])
    init_config: dict = field(default_factory=dict)
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""
    # experiment-related parameters
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None
    seed: Optional[int] = None
    max_train_steps: int = 1000000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    use_8bit_adam: bool = False
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    mixed_precision: Optional[str] = None       # ["no", "fp16", "bf16", "fp8"]
    skip_training: bool = False
    debug: bool = False
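These dataclasses are not instantiated by hand; they are filled from YAML through load_config (defined in custum_3d_diffusion/trainings/utils.py below). A hedged sketch of that round trip, using one of the configs under gradio_app/custom_models/ as input:

from custum_3d_diffusion.trainings.config_classes import ExprimentConfig
from custum_3d_diffusion.trainings.utils import load_config

# load_config parses the YAML with OmegaConf, resolves ${...} interpolations and
# validates the result against the ExprimentConfig schema above.
cfg = load_config(ExprimentConfig, "gradio_app/custom_models/image2mvimage.yaml")
print(cfg.pretrained_model_name_or_path, cfg.mixed_precision, len(cfg.trainers))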
    	
custum_3d_diffusion/trainings/image2image_trainer.py
ADDED
@@ -0,0 +1,86 @@
import json
import torch
from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
from dataclasses import dataclass

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def get_HW(resolution):
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    return H, W


@register("image2image_trainer")
class Image2ImageTrainer(Image2MVImageTrainer):
    """
    Trainer for simple image-to-image generation (e.g. RGB view to normal map).
    """
    @dataclass
    class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
        trainer_name: str = "image2image"

    cfg: TrainerConfig

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImageCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')

        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        forward_args = dict(
            num_images_per_prompt=1,
            num_inference_steps=20,
            height=H,
            width=W,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
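The get_HW helper above (repeated verbatim in image2mvimage_trainer.py below) accepts either a JSON-encoded integer or an [H, W] list; a self-contained check of the two accepted forms:

import json

def get_HW(resolution):
    # Same logic as above: a string is JSON-decoded first, then treated as int or [H, W].
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    return H, W

assert get_HW("512") == (512, 512)          # single integer -> square resolution
assert get_HW("[768, 512]") == (768, 512)   # [H, W] list -> rectangular resolution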
    	
custum_3d_diffusion/trainings/image2mvimage_trainer.py
ADDED
@@ -0,0 +1,139 @@
import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature

import json
from dataclasses import dataclass
from typing import List, Optional

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def get_HW(resolution):
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    return H, W

@register("image2mvimage_trainer")
class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for simple image to multiview images.
    """
    @dataclass
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.
        condition_image_resolution: str = "512"
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1
        max_loss_drop: float = 0.
        snr_gamma: float = 5.0
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None
        input_perturbation: float = 0.
        noisy_condition_input: bool = False                 # whether to add noise for ref unet input
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False
        linear_beta_schedule: bool = False

    cfg: TrainerConfig

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules

    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')

        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        sub_img_H = H // 2
        num_imgs = H // sub_img_H * W // sub_img_H

        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
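The sub-image arithmetic in get_forward_args determines how many views the multiview pipeline produces per condition image; with the "256" resolution used by gradio_app/custom_models/image2mvimage.yaml the numbers work out as follows:

H = W = 256                                   # resolution: "256" in image2mvimage.yaml
sub_img_H = H // 2                            # 128: each generated view is half the full height
num_imgs = H // sub_img_H * W // sub_img_H    # 2 * 2 = 4 views requested per condition image
print(sub_img_H, num_imgs)                    # -> 128 4

Callers such as run_mvprediction later override height and width to 256 through pipeline_call_kwargs while keeping the four-view request.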
    	
custum_3d_diffusion/trainings/utils.py
ADDED
@@ -0,0 +1,25 @@
from omegaconf import DictConfig, OmegaConf


def parse_structured(fields, cfg) -> DictConfig:
    scfg = OmegaConf.structured(fields(**cfg))
    return scfg


def load_config(fields, config, extras=None):
    if extras is not None:
        print("Warning! Extra CLI parameters are not verified and may cause errors.")
    if isinstance(config, str):
        cfg = OmegaConf.load(config)
    elif isinstance(config, dict):
        cfg = OmegaConf.create(config)
    elif isinstance(config, DictConfig):
        cfg = config
    else:
        raise NotImplementedError(f"Unsupported config type {type(config)}")
    if extras is not None:
        cli_conf = OmegaConf.from_cli(extras)
        cfg = OmegaConf.merge(cfg, cli_conf)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return parse_structured(fields, cfg)
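load_config accepts a YAML path, a plain dict, or an existing DictConfig, and optionally merges dotlist-style CLI extras on top before validating against the dataclass schema. A small, hedged demonstration with TrainerSubConfig from config_classes.py:

from custum_3d_diffusion.trainings.config_classes import TrainerSubConfig
from custum_3d_diffusion.trainings.utils import load_config

base = {"trainer_type": "image2mvimage_trainer", "trainer": {"resolution": "256"}}
# The extras list is merged after the base config, so it overrides trainer.resolution.
cfg = load_config(TrainerSubConfig, base, extras=["trainer.resolution=512"])
print(cfg.trainer_type, cfg.trainer.resolution)   # image2mvimage_trainer 512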
    	
gradio_app/__init__.py
ADDED
File without changes
    	
gradio_app/all_models.py
ADDED
@@ -0,0 +1,22 @@
import torch
from scripts.sd_model_zoo import load_common_sd15_pipe
from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline


class MyModelZoo:
    _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None

    base_model = "benjamin-paine/stable-diffusion-v1-5"

    def __init__(self, base_model=None) -> None:
        if base_model is not None:
            self.base_model = base_model

    @property
    def pipe_disney_controlnet_tile_ipadapter_i2i(self):
        return self._pipe_disney_controlnet_lineart_ipadapter_i2i

    def init_models(self):
        self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)

model_zoo = MyModelZoo()
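The zoo is deliberately lazy: nothing is loaded at import time, and the app is expected to call init_models once before touching the property. A hedged usage sketch, assuming the ./ckpt/controlnet-tile weights and the scripts.sd_model_zoo helper are present locally:

from gradio_app.all_models import model_zoo

model_zoo.init_models()   # loads the SD 1.5 ControlNet-tile img2img pipeline with IP-Adapter
pipe = model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i
print(type(pipe).__name__)   # StableDiffusionControlNetImg2ImgPipeline

Note that the backing attribute keeps the older _lineart_ name even though the checkpoint it loads is the tile ControlNet, so the property is the stable way to reach it.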
    	
gradio_app/custom_models/image2mvimage.yaml
ADDED
@@ -0,0 +1,63 @@
pretrained_model_name_or_path: "./ckpt/img2mvimg"
mixed_precision: "bf16"

init_config:
  # enable controls
  enable_cross_attn_lora: False
  enable_cross_attn_ip: False
  enable_self_attn_lora: False
  enable_self_attn_ref: False
  enable_multiview_attn: True

  # for cross attention
  init_cross_attn_lora: False
  init_cross_attn_ip: False
  cross_attn_lora_rank: 256        # 0 for not enabled
  cross_attn_lora_only_kv: False
  ipadapter_pretrained_name: "h94/IP-Adapter"
  ipadapter_subfolder_name: "models"
  ipadapter_weight_name: "ip-adapter_sd15.safetensors"
  ipadapter_effect_on: "all"    # all, first

  # for self attention
  init_self_attn_lora: False
  self_attn_lora_rank: 256
  self_attn_lora_only_kv: False

  # for self attention ref
  init_self_attn_ref: False
  self_attn_ref_position: "attn1"
  self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
  self_attn_ref_pixel_wise_crosspond: False
  self_attn_ref_effect_on: "all"

  # for multiview attention
  init_multiview_attn: True
  multiview_attn_position: "attn1"
  use_mv_joint_attn: True
  num_modalities: 1

  # for unet
  init_unet_path: "${pretrained_model_name_or_path}"
  cat_condition: True       # cat condition to input

  # for cls embedding
  init_num_cls_label: 8     # for initialize
  cls_labels: [0, 1, 2, 3]  # for current task

trainers:
  - trainer_type: "image2mvimage_trainer"
    trainer:
        pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
        attn_config:
          cls_labels: [0, 1, 2, 3]  # for current task
          enable_cross_attn_lora: False
          enable_cross_attn_ip: False
          enable_self_attn_lora: False
          enable_self_attn_ref: False
          enable_multiview_attn: True
        resolution: "256"
        condition_image_resolution: "256"
        normal_cls_offset: 4
        condition_image_column_name: "conditioning_image"
        image_column_name: "image"
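The "${pretrained_model_name_or_path}" entries above are OmegaConf interpolations, resolved when load_config calls OmegaConf.resolve; a minimal self-contained illustration:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "pretrained_model_name_or_path": "./ckpt/img2mvimg",
    "init_config": {"init_unet_path": "${pretrained_model_name_or_path}"},
})
OmegaConf.resolve(cfg)                      # substitutes the top-level value in place
print(cfg.init_config.init_unet_path)       # -> ./ckpt/img2mvimg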
    	
gradio_app/custom_models/image2normal.yaml
ADDED
@@ -0,0 +1,61 @@
pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
mixed_precision: "bf16"

init_config:
  # enable controls
  enable_cross_attn_lora: False
  enable_cross_attn_ip: False
  enable_self_attn_lora: False
  enable_self_attn_ref: True
  enable_multiview_attn: False

  # for cross attention
  init_cross_attn_lora: False
  init_cross_attn_ip: False
  cross_attn_lora_rank: 512        # 0 for not enabled
  cross_attn_lora_only_kv: False
  ipadapter_pretrained_name: "h94/IP-Adapter"
  ipadapter_subfolder_name: "models"
  ipadapter_weight_name: "ip-adapter_sd15.safetensors"
  ipadapter_effect_on: "all"    # all, first

  # for self attention
  init_self_attn_lora: False
  self_attn_lora_rank: 512
  self_attn_lora_only_kv: False

  # for self attention ref
  init_self_attn_ref: True
  self_attn_ref_position: "attn1"
  self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
  self_attn_ref_pixel_wise_crosspond: True
  self_attn_ref_effect_on: "all"

  # for multiview attention
  init_multiview_attn: False
  multiview_attn_position: "attn1"
  num_modalities: 1

  # for unet
  init_unet_path: "${pretrained_model_name_or_path}"
  init_num_cls_label: 0     # for initialize
  cls_labels: []  # for current task

trainers:
  - trainer_type: "image2image_trainer"
    trainer:
        pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
        attn_config:
          cls_labels: []  # for current task
          enable_cross_attn_lora: False
          enable_cross_attn_ip: False
          enable_self_attn_lora: False
          enable_self_attn_ref: True
          enable_multiview_attn: False
        resolution: "512"
        condition_image_resolution: "512"
        condition_image_column_name: "conditioning_image"
        image_column_name: "image"
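Compared with image2mvimage.yaml, this config swaps multiview attention for pixel-wise reference attention and drops the class-label embedding; an illustrative way to diff the two files (run from the repo root with omegaconf installed):

from omegaconf import OmegaConf

mv = OmegaConf.load("gradio_app/custom_models/image2mvimage.yaml").init_config
nrm = OmegaConf.load("gradio_app/custom_models/image2normal.yaml").init_config
for key in ("enable_multiview_attn", "enable_self_attn_ref",
            "self_attn_ref_pixel_wise_crosspond", "init_num_cls_label"):
    print(key, mv[key], nrm[key])
# enable_multiview_attn: True vs False, enable_self_attn_ref: False vs True,
# self_attn_ref_pixel_wise_crosspond: False vs True, init_num_cls_label: 8 vs 0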
    	
gradio_app/custom_models/mvimg_prediction.py
ADDED
@@ -0,0 +1,59 @@
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from rembg import remove
from gradio_app.utils import change_rgba_bg, rgba_to_rgb
from gradio_app.custom_models.utils import load_pipeline
from scripts.all_typing import *
from scripts.utils import session, simple_preprocess

training_config = "gradio_app/custom_models/image2mvimage.yaml"
checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"

trainer, pipeline = load_pipeline(training_config, checkpoint_path)

def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
    global pipeline
    pipeline = pipeline.to("cuda")
    if isinstance(img_list, Image.Image):
        img_list = [img_list]
    img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
    ret = []
    for img in img_list:
        images = trainer.pipeline_forward(
            pipeline=pipeline,
            image=img,
            guidance_scale=guidance_scale,
            **kwargs
        ).images
        ret.extend(images)
    return ret


def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
    if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
        # still do remove using rembg, since simple_preprocess requires RGBA image
        print("RGB image not RGBA! still remove bg!")
        remove_bg = True

    if remove_bg:
        input_image = remove(input_image, session=session)

    # make front_pil RGBA with white bg
    input_image = change_rgba_bg(input_image, "white")
    single_image = simple_preprocess(input_image)

    generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None

    rgb_pils = predict(
        single_image,
        generator=generator,
        guidance_scale=guidance_scale,
        width=256,
        height=256,
        num_inference_steps=30,
    )

    return rgb_pils, single_image
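A hedged end-to-end sketch of run_mvprediction; it needs a CUDA device plus the ckpt/img2mvimg weights, since this module loads its pipeline at import time, and it uses one of the bundled example images:

from PIL import Image
from gradio_app.custom_models.mvimg_prediction import run_mvprediction

front = Image.open("gradio_app/examples/ex1.png")
# Returns the generated views (four with the default 256 grid layout)
# and the preprocessed RGBA front image.
views, preprocessed = run_mvprediction(front, remove_bg=True, guidance_scale=1.5, seed=1145)
print(len(views))            # -> 4
preprocessed.save("front_preprocessed.png")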
    	
gradio_app/custom_models/normal_prediction.py
ADDED
@@ -0,0 +1,28 @@
import sys
from PIL import Image
from gradio_app.utils import rgba_to_rgb, simple_remove
from gradio_app.custom_models.utils import load_pipeline
from scripts.utils import rotate_normals_torch
from scripts.all_typing import *

training_config = "gradio_app/custom_models/image2normal.yaml"
checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
trainer, pipeline = load_pipeline(training_config, checkpoint_path)

def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
    global pipeline
    pipeline = pipeline.to("cuda")

    img_list = image if isinstance(image, list) else [image]
    img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
    images = trainer.pipeline_forward(
        pipeline=pipeline,
        image=img_list,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        **kwargs
    ).images
    images = simple_remove(images)
    if do_rotate and len(images) > 1:
        images = rotate_normals_torch(images, return_types='pil')
    return images
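The normal predictor is meant to be chained after the multiview RGB step; a hedged sketch of that chaining, under the same CUDA and checkpoint assumptions as above:

from PIL import Image
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals

rgb_views, _ = run_mvprediction(Image.open("gradio_app/examples/ex2.png"), seed=-1)
# One normal map comes back per input view; do_rotate passes them through rotate_normals_torch.
normal_views = predict_normals(rgb_views, guidance_scale=2.0, do_rotate=True)
print(len(normal_views))     # matches len(rgb_views)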
    	
        gradio_app/custom_models/utils.py
    ADDED
    
    | @@ -0,0 +1,75 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
import torch
from typing import List
from dataclasses import dataclass
from gradio_app.utils import rgba_to_rgb
from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
from custum_3d_diffusion import modules
from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.trainings.utils import load_config


@dataclass
class FakeAccelerator:
    # stand-in for the training-time Accelerator when trainers are built for inference only
    device: torch.device = torch.device("cuda")


def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
    accelerator = FakeAccelerator()
    cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
    init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
    configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
    configurable_unet.enable_xformers_memory_efficient_attention()
    trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
    trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
    return trainers, configurable_unet

from gradio_app.utils import make_image_grid, split_image
def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
    from rembg import remove
    if remove_bg:
        img = remove(img)
    img = rgba_to_rgb(img)
    if merged_image:
        img = split_image(img, rows=2)
    images = function(
        image=img,
        guidance_scale=guidance_scale,
    )
    if len(images) > 1:
        return make_image_grid(images, rows=2)
    else:
        return images[0]


def process_text(trainer, pipeline, img, guidance_scale=2.):
    pipeline.cfg.validation_prompts = [img]
    titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
    return images[0]


def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
    training_config = config_path
    load_from_checkpoint = ckpt_path
    extras = []
    device = "cuda"
    trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
    shared_modules = dict()
    for trainer in trainers:
        shared_modules = trainer.init_shared_modules(shared_modules)

    if load_from_checkpoint is not None:
        state_dict = torch.load(load_from_checkpoint, map_location="cpu")
        configurable_unet.unet.load_state_dict(state_dict, strict=False)
    # Move unet, vae and text_encoder to device and cast to weight_dtype
    configurable_unet.unet.to(device, dtype=weight_dtype)

    pipeline = None
    trainer_out = None
    for trainer in trainers:
        if pipeline_filter(trainer.cfg.trainer_name):
            pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
            pipeline.set_progress_bar_config(disable=False)
            trainer_out = trainer
    pipeline = pipeline.to(device, dtype=weight_dtype)
    return trainer_out, pipeline
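A hedged usage sketch of `load_pipeline` together with `process_image`. The config and checkpoint paths mirror the ones used in `normal_prediction.py` above; the `run` wrapper is illustrative only and not part of the repo:

import torch
from gradio_app.custom_models.utils import load_pipeline, process_image

trainer, pipeline = load_pipeline(
    "gradio_app/custom_models/image2normal.yaml",   # training config shipped with the app
    "ckpt/image2normal/unet_state_dict.pth",        # UNet weights expected by the app
    weight_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")

def run(image, guidance_scale=2.):
    # mirrors the pipeline_forward call in predict_normals; accepts one image or a list
    imgs = image if isinstance(image, list) else [image]
    return trainer.pipeline_forward(pipeline=pipeline, image=imgs,
                                    num_inference_steps=30, guidance_scale=guidance_scale).images

# result = process_image(run, some_pil_image, guidance_scale=2., merged_image=False)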
    	
gradio_app/examples/Groot.png ADDED (Git LFS)
gradio_app/examples/aaa.png ADDED (Git LFS)
gradio_app/examples/abma.png ADDED (Git LFS)
gradio_app/examples/akun.png ADDED (Git LFS)
gradio_app/examples/anya.png ADDED (Git LFS)
gradio_app/examples/bag.png ADDED (Git LFS)
gradio_app/examples/ex1.png ADDED (Git LFS)
gradio_app/examples/ex2.png ADDED (Git LFS)
gradio_app/examples/ex3.jpg ADDED
gradio_app/examples/ex4.png ADDED (Git LFS)
gradio_app/examples/generated_1715761545_frame0.png ADDED (Git LFS)
gradio_app/examples/generated_1715762357_frame0.png ADDED (Git LFS)
gradio_app/examples/generated_1715763329_frame0.png ADDED (Git LFS)
gradio_app/examples/hatsune_miku.png ADDED (Git LFS)
gradio_app/examples/princess-large.png ADDED (Git LFS)
    	
        gradio_app/gradio_3dgen.py
    ADDED
    
import spaces
import os
import gradio as gr
from PIL import Image
from pytorch3d.structures import Meshes
from gradio_app.utils import clean_up
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.refine_lr_to_sr import run_sr_fast
from scripts.utils import save_glb_and_video
# from scripts.multiview_inference import geo_reconstruct
from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3

@spaces.GPU(duration=100)
def run_mv(preview_img, input_processing, seed):
    if preview_img.size[0] <= 512:
        preview_img = run_sr_fast([preview_img])[0]
    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
    return rgb_pils, front_pil

@spaces.GPU(duration=100) # splitting this into several GPU calls seems to trigger a `RuntimeError`; until that is fixed, keep the whole reconstruction in one call
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
    if preview_img is None:
        raise gr.Error("The input image is none!")
    if isinstance(preview_img, str):
        preview_img = Image.open(preview_img)

    rgb_pils, front_pil = run_mv(preview_img, input_processing, seed)

    vertices, faces, img_list = geo_reconstruct_part1(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)

    meshes = geo_reconstruct_part2(vertices, faces)

    new_meshes = geo_reconstruct_part3(meshes, img_list)

    # rescale and flip X/Z so the exported GLB matches the expected orientation
    vertices = new_meshes.verts_packed()
    vertices = vertices / 2 * 1.35
    vertices[..., [0, 2]] = - vertices[..., [0, 2]]
    new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)

    ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
    return ret_mesh, video

#######################################
def create_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')

            example_folder = os.path.join(os.path.dirname(__file__), "./examples")
            example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
            gr.Examples(
                examples=example_fns,
                inputs=[input_image],
                cache_examples=False,
                label='Examples',
                examples_per_page=12
            )

        with gr.Column(scale=1):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320, camera_position=(90, 90, 2))
            output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)

            input_processing = gr.Checkbox(
                value=True,
                label='Remove Background',
                visible=True,
            )
            do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
            expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
            setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
            render_video = gr.Checkbox(value=False, visible=False, label="Generate Video")
            fullrunv2_btn = gr.Button('Generate 3D', variant = "primary", interactive=True)

    fullrunv2_btn.click(
        fn = generate3dv2,
        inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
        outputs=[output_mesh, output_video],
        concurrency_id=concurrency_id,
        api_name="generate3dv2",
    ).success(clean_up, api_name=False)
    return input_image
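A minimal sketch of mounting this tab outside the Space, mirroring what `gradio_local.py` further below does ("wkl" is simply the concurrency id used throughout this repo):

import gradio as gr
from gradio_app.all_models import model_zoo
from gradio_app.gradio_3dgen import create_ui

model_zoo.init_models()                          # load the multiview / normal pipelines once
with gr.Blocks(title="Unique3D") as demo:
    create_ui(concurrency_id="wkl")
demo.queue(default_concurrency_limit=1).launch()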
    	
        gradio_app/gradio_3dgen_steps.py
    ADDED
    
import gradio as gr
from PIL import Image

from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.utils import make_image_grid, split_image
from scripts.utils import save_glb_and_video

def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
    seed = int(seed)
    if preview_img is None:
        raise gr.Error("preview_img is none.")
    if isinstance(preview_img, str):
        preview_img = Image.open(preview_img)

    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
    rgb_pil = make_image_grid(rgb_pils, rows=2)
    return rgb_pil, front_pil

def concept_to_multiview_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=2):
            preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
            input_processing = gr.Checkbox(
                value=True,
                label='Remove Background',
            )
            seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="Seed")
            guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
            run_btn = gr.Button('Generate Multiview', interactive=True)
        with gr.Column(scale=3):
            # multiview grid and front view display
            output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
            output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
    run_btn.click(
        fn = concept_to_multiview,
        inputs=[preview_img, input_processing, seed, guidance],
        outputs=[output_rgb, output_front],
        concurrency_id=concurrency_id,
        api_name=False,
    )
    return output_rgb, output_front

from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.multiview_inference import geo_reconstruct
def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
    rgb_pils = split_image(rgb_pil, rows=2)
    if normal_pil is not None:
        normal_pil = split_image(normal_pil, rows=2)
    if front_pil is None:
        front_pil = rgb_pils[0]
    new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
    ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
    return ret_mesh

def new_multiview_to_mesh_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=2):
            rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
            front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)')
            normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)')
            do_refine = gr.Checkbox(
                value=False,
                label='Refine RGB',
                visible=False,
            )
            expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
            run_btn = gr.Button('Generate 3D', interactive=True)
        with gr.Column(scale=3):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True)
    run_btn.click(
        fn = multiview_to_mesh_v2,
        inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
        outputs=[output_mesh],
        concurrency_id=concurrency_id,
        api_name="multiview_to_mesh",
    )
    return rgb_pil, front_pil, output_mesh


#######################################
def create_step_ui(concurrency_id="wkl"):
    with gr.Tab(label="3D:concept_to_multiview"):
        concept_to_multiview_ui(concurrency_id)
    with gr.Tab(label="3D:new_multiview_to_mesh"):
        new_multiview_to_mesh_ui(concurrency_id)
    	
        gradio_app/gradio_local.py
    ADDED
    
if __name__ == "__main__":
    import os
    import sys
    sys.path.append(os.curdir)
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ['TRANSFORMERS_OFFLINE'] = '0'
    os.environ['DIFFUSERS_OFFLINE'] = '0'
    os.environ['HF_HUB_OFFLINE'] = '0'
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
    import torch
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_grad_enabled(False)

import gradio as gr
import argparse

from gradio_app.gradio_3dgen import create_ui as create_3d_ui
# from app.gradio_3dgen_steps import create_step_ui
from gradio_app.all_models import model_zoo


_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
_DESCRIPTION = '''
[Project page](https://wukailu.github.io/Unique3D/)

* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.

* The demo is still under construction, and more features are expected to be implemented soon.
'''

def launch(
    port,
    listen=False,
    share=False,
    gradio_root="",
):
    model_zoo.init_models()

    with gr.Blocks(
        title=_TITLE,
        theme=gr.themes.Monochrome(),
    ) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)
        create_3d_ui("wkl")

    launch_args = {}
    if listen:
        launch_args["server_name"] = "0.0.0.0"

    demo.queue(default_concurrency_limit=1).launch(
        server_port=None if port == 0 else port,
        share=share,
        root_path=gradio_root if gradio_root != "" else None,  # "/myapp"
        **launch_args,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--listen", action="store_true")
    parser.add_argument("--port", type=int, default=0)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--gradio_root", default="")
    args = parser.parse_args()
    launch(
        args.port,
        listen=args.listen,
        share=args.share,
        gradio_root=args.gradio_root,
    )
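The parser above exposes `--listen`, `--port`, `--share` and `--gradio_root`, so the usual entry point is `python gradio_app/gradio_local.py --listen --port 7860`. A hedged sketch of driving the same `launch` function programmatically (note that the environment setup at the top of the file only runs under `__main__`):

from gradio_app.gradio_local import launch

# roughly equivalent to: python gradio_app/gradio_local.py --listen --port 7860
launch(7860, listen=True, share=False)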
    	
        gradio_app/utils.py
    ADDED
    
import torch
import numpy as np
import gc
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast

GRADIO_CACHE = "/tmp/gradio/"

def clean_up():
    torch.cuda.empty_cache()
    gc.collect()

def remove_color(arr):
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # calc diffs against the top-left pixel, which is assumed to be background
    base = arr[0, 0]
    diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
    alpha = (diffs <= 80)

    arr[alpha] = 255
    alpha = ~alpha
    arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
    return arr

def simple_remove(imgs, run_sr=True):
    """Only works for normal"""
    if not isinstance(imgs, list):
        imgs = [imgs]
        single_input = True
    else:
        single_input = False
    if run_sr:
        imgs = run_sr_fast(imgs)
    rets = []
    for img in imgs:
        arr = np.array(img)
        arr = remove_color(arr)
        rets.append(Image.fromarray(arr.astype(np.uint8)))
    if single_input:
        return rets[0]
    return rets

def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    new_image = Image.new("RGBA", rgba.size, bkgd)
    new_image.paste(rgba, (0, 0), rgba)
    new_image = new_image.convert('RGB')
    return new_image

def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    rgb_white = rgba_to_rgb(rgba, bkgd)
    new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
    return new_rgba

def split_image(image, rows=None, cols=None):
    """
        inverse function of make_image_grid
    """
    # image is in square
    if rows is None and cols is None:
        # image.size [W, H]
        rows = 1
        cols = image.size[0] // image.size[1]
        assert cols * image.size[1] == image.size[0]
        subimg_size = image.size[1]
    elif rows is None:
        subimg_size = image.size[0] // cols
        rows = image.size[1] // subimg_size
        assert rows * subimg_size == image.size[1]
    elif cols is None:
        subimg_size = image.size[1] // rows
        cols = image.size[0] // subimg_size
        assert cols * subimg_size == image.size[0]
    else:
        subimg_size = image.size[1] // rows
        assert cols * subimg_size == image.size[0]
    subimgs = []
    for i in range(rows):
        for j in range(cols):
            subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
            subimgs.append(subimg)
    return subimgs

def make_image_grid(images, rows=None, cols=None, resize=None):
    if rows is None and cols is None:
        rows = 1
        cols = len(images)
    if rows is None:
        rows = len(images) // cols
        if len(images) % cols != 0:
            rows += 1
    if cols is None:
        cols = len(images) // rows
        if len(images) % rows != 0:
            cols += 1
    total_imgs = rows * cols
    if total_imgs > len(images):
        images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
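`split_image` is documented as the inverse of `make_image_grid`; a small self-contained round-trip check (the solid-colour tiles are only for illustration):

from PIL import Image
from gradio_app.utils import make_image_grid, split_image

tiles = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
grid = make_image_grid(tiles, rows=2)      # 2x2 grid, 512x512
recovered = split_image(grid, rows=2)      # back to 4 tiles of 256x256
assert len(recovered) == 4 and recovered[0].size == (256, 256)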
    	
        mesh_reconstruction/func.py
    ADDED
    
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import numpy as np
import trimesh
from typing import Tuple

def to_numpy(*args):
    def convert(a):
        if isinstance(a,torch.Tensor):
            return a.detach().cpu().numpy()
        assert a is None or isinstance(a,np.ndarray)
        return a

    return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args)

def laplacian(
        num_verts:int,
        edges: torch.Tensor #E,2
        ) -> torch.Tensor: #sparse V,V
    """create sparse Laplacian matrix"""
    V = num_verts
    E = edges.shape[0]

    #adjacency matrix,
    idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T  # (2, 2*E)
    ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
    A = torch.sparse.FloatTensor(idx, ones, (V, V))

    #degree matrix
    deg = torch.sparse.sum(A, dim=1).to_dense()
    idx = torch.arange(V, device=edges.device)
    idx = torch.stack([idx, idx], dim=0)
    D = torch.sparse.FloatTensor(idx, deg, (V, V))

    return D - A

def _translation(x, y, z, device):
    return torch.tensor([[1., 0, 0, x],
                    [0, 1, 0, y],
                    [0, 0, 1, z],
                    [0, 0, 0, 1]],device=device) #4,4

def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    """
        see https://blog.csdn.net/wodownload2/article/details/85069240/
    """
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    p = torch.zeros([4,4],device=device)
    p[0,0] = 2*n/(r-l)
    p[0,2] = (r+l)/(r-l)
    p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
    p[1,2] = (t+b)/(t-b)
    p[2,2] = -(f+n)/(f-n)
    p[2,3] = -(2*f*n)/(f-n)
    p[3,2] = -1
    return p #4,4

def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    o = torch.zeros([4,4],device=device)
    o[0,0] = 2/(r-l)
    o[0,3] = -(r+l)/(r-l)
    o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
    o[1,3] = -(t+b)/(t-b)
    o[2,2] = -2/(f-n)
    o[2,3] = -(f+n)/(f-n)
    o[3,3] = 1
    return o #4,4

def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
    if r is None:
        r = 1/distance
    A = az_count
    P = pol_count
    C = A * P

    phi = torch.arange(0,A) * (2*torch.pi/A)
    phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
    phi_rot[:,0,2,2] = phi.cos()
    phi_rot[:,0,2,0] = -phi.sin()
    phi_rot[:,0,0,2] = phi.sin()
    phi_rot[:,0,0,0] = phi.cos()

    theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
    theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
    theta_rot[0,:,1,1] = theta.cos()
    theta_rot[0,:,1,2] = -theta.sin()
    theta_rot[0,:,2,1] = theta.sin()
    theta_rot[0,:,2,2] = theta.cos()

    mv = torch.empty((C,4,4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
    mv = _translation(0, 0, -distance, device) @ mv

    return mv, _projection(r,device)

def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
    mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
    if r is None:
        r = 1
    return mv, _orthographic(r,device)

def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]:
    sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
    vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
    faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
    return vertices,faces

from pytorch3d.renderer import (
    FoVOrthographicCameras,
    look_at_view_transform,
)

def get_camera(R, T, focal_length=1 / (2**0.5)):
    focal_length = 1 / focal_length
    camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
    return camera

def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
    R, T = look_at_view_transform(dist, 0, azim_list)
    focal_length = 1 / focal
    return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
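A short, hedged example of the camera helpers above: four orthographic views spaced 90 degrees apart in azimuth, plus an icosphere to start remeshing from (assumes a CUDA device, which is the functions' default):

from mesh_reconstruction.func import make_star_cameras_orthographic, make_sphere

mv, proj = make_star_cameras_orthographic(az_count=4, pol_count=1, distance=10.)
print(mv.shape, proj.shape)    # torch.Size([4, 4, 4]) torch.Size([4, 4])

vertices, faces = make_sphere(level=2, radius=0.5)   # starting geometry for the optimizer below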
    	
        mesh_reconstruction/opt.py
    ADDED
    
# modified from https://github.com/Profactor/continuous-remeshing
import time
import torch
import torch_scatter
from typing import Tuple
from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges

@torch.no_grad()
def remesh(
        vertices_etc:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
        min_edgelen:torch.Tensor, #V
        max_edgelen:torch.Tensor, #V
        flip:bool,
        max_vertices=1e6
        ):

    # dummies
    vertices_etc,faces = prepend_dummies(vertices_etc,faces)
    vertices = vertices_etc[:,:3] #V,3
    nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
    min_edgelen = torch.concat((nan_tensor,min_edgelen))
    max_edgelen = torch.concat((nan_tensor,max_edgelen))

    # collapse
    edges,face_to_edge = calc_edges(faces) #E,2 F,3
    edge_length = calc_edge_length(vertices,edges) #E
    face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
    vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
    face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
    priority = face_collapse.float() + shortness
    vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)

    # split
    if vertices.shape[0]<max_vertices:
        edges,face_to_edge = calc_edges(faces) #E,2 F,3
        vertices = vertices_etc[:,:3] #V,3
        edge_length = calc_edge_length(vertices,edges) #E
        splits = edge_length > max_edgelen[edges].mean(dim=-1)
        vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)

    vertices_etc,faces = pack(vertices_etc,faces)
    vertices = vertices_etc[:,:3]

    if flip:
        edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
        flip_edges(vertices,faces,edges,edge_to_face,with_border=False)

    return remove_dummies(vertices_etc,faces)

def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
    """lerp with adam's bias correction"""
    c_prev = 1-weight**(step-1)
    c = 1-weight**step
    a_weight = weight*c_prev/c
    b_weight = (1-weight)/c
    a.mul_(a_weight).add_(b, alpha=b_weight)


class MeshOptimizer:
    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""

    def __init__(self,
            vertices:torch.Tensor, #V,3
            faces:torch.Tensor, #F,3
            lr=0.3, #learning rate
            betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
            gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
            nu_ref=0.3, #reference velocity for edge length controller
            edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
            edge_len_tol=.5, #edge length tolerance for split and collapse
            gain=.2,  #gain value for edge length controller
            laplacian_weight=.02, #for laplacian smoothing/regularization
            ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
            grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
            remesh_interval=1, #larger intervals are faster but with worse mesh quality
            local_edgelen=True, #set to False to use a global scalar reference edge length instead
            ):
        self._vertices = vertices
        self._faces = faces
        self._lr = lr
        self._betas = betas
        self._gammas = gammas
        self._nu_ref = nu_ref
        self._edge_len_lims = edge_len_lims
        self._edge_len_tol = edge_len_tol
        self._gain = gain
        self._laplacian_weight = laplacian_weight
        self._ramp = ramp
        self._grad_lim = grad_lim
        self._remesh_interval = remesh_interval
        self._local_edgelen = local_edgelen
        self._step = 0

        V = self._vertices.shape[0]
        # prepare continuous tensor for all vertex-based data
        self._vertices_etc = torch.zeros([V,9],device=vertices.device)
        self._split_vertices_etc()
        self.vertices.copy_(vertices) #initialize vertices
        self._vertices.requires_grad_()
        self._ref_len.fill_(edge_len_lims[1])

    @property
    def vertices(self):
        return self._vertices

    @property
    def faces(self):
        return self._faces

    def _split_vertices_etc(self):
        self._vertices = self._vertices_etc[:,:3]
        self._m2 = self._vertices_etc[:,3]
        self._nu = self._vertices_etc[:,4]
        self._m1 = self._vertices_etc[:,5:8]
        self._ref_len = self._vertices_etc[:,8]

        with_gammas = any(g!=0 for g in self._gammas)
        self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]

    def zero_grad(self):
        self._vertices.grad = None

    @torch.no_grad()
    def step(self):

        eps = 1e-8

        self._step += 1

        # spatial smoothing
        edges,_ = calc_edges(self._faces) #E,2
        E = edges.shape[0]
        edge_smooth = self._smooth[edges] #E,2,S
        neighbor_smooth = torch.zeros_like(self._smooth) #V,S
        torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)

        #apply optional smoothing of m1,m2,nu
        if self._gammas[0]:
            self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
        if self._gammas[1]:
            self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
        if self._gammas[2]:
            self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])

        #add laplace smoothing to gradients
        laplace = self._vertices - neighbor_smooth[:,:3]
        grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)

        #gradient clipping
        if self._step>1:
            grad_lim = self._m1.abs().mul_(self._grad_lim)
            grad.clamp_(min=-grad_lim,max=grad_lim)

        # moment updates
        lerp_unbiased(self._m1, grad, self._betas[0], self._step)
        lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)

        velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
        speed = velocity.norm(dim=-1) #V

        if self._betas[2]:
            lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
        else:
            self._nu.copy_(speed) #V

        # update vertices
        ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
        self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)

        # update target edge length
        if self._step % self._remesh_interval == 0:
            if self._local_edgelen:
                len_change = (1 + (self._nu - self._nu_ref) * self._gain)
            else:
                len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
            self._ref_len *= len_change
            self._ref_len.clamp_(*self._edge_len_lims)

    def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
        min_edge_len = self._ref_len * (1 - self._edge_len_tol)
        max_edge_len = self._ref_len * (1 + self._edge_len_tol)

        self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)

        self._split_vertices_etc()
        self._vertices.requires_grad_()

        return self._vertices, self._faces
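The MeshOptimizer docstring above describes the intended call pattern: use it like a PyTorch optimizer, then re-bind the mesh after each remeshing pass. The following is a minimal usage sketch, not part of the repository; `initial_vertices`/`initial_faces` are placeholder names for any coarse starting mesh on the GPU, and the toy loss stands in for the real rendering objective.

from mesh_reconstruction.opt import MeshOptimizer

# hypothetical inputs: any (V,3) float / (F,3) long mesh tensors on the GPU
vertices, faces = initial_vertices.cuda(), initial_faces.cuda()

opt = MeshOptimizer(vertices, faces, edge_len_lims=(0.01, 0.15))
vertices = opt.vertices  # optimized vertices live inside the optimizer's buffer

for step in range(100):
    opt.zero_grad()
    loss = vertices.pow(2).sum(dim=-1).mean()  # toy objective: pull vertices toward the origin
    loss.backward()
    opt.step()                      # Adam-like update plus laplacian regularization
    vertices, faces = opt.remesh()  # split/collapse/flip edges, then re-bind the tensors

Re-assigning `vertices` and `faces` after `remesh()` matters: remeshing changes the vertex count and re-allocates the underlying buffer, so stale references would no longer track the optimized mesh.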
    	
mesh_reconstruction/recon.py ADDED
@@ -0,0 +1,59 @@
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from typing import List
from mesh_reconstruction.remesh import calc_vertex_normals
from mesh_reconstruction.opt import MeshOptimizer
from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from scripts.utils import to_py3d_mesh, init_target

def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
    vertices, faces = vertices.to("cuda"), faces.to("cuda")
    assert len(pils) == 4
    mv,proj = make_star_cameras_orthographic(4, 1)
    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
    # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
    # 1. no rotate
    target_images = target_images[[0, 3, 2, 1]]

    # 2. init from coarse mesh
    opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))

    vertices = opt.vertices

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)

        loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()

        t_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight

        # out of box
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        vertices,faces = opt.remesh(poisson=False)

    vertices, faces = vertices.detach().cpu(), faces.detach().cpu()

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
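`reconstruct_stage1` drives exactly this optimizer loop against four orthographic normal-map views. A hedged call sketch follows; `mv_normal_pils`, `init_vertices`, and `init_faces` are placeholder names for the four view images and the coarse starting mesh, not identifiers defined in this excerpt.

from mesh_reconstruction.recon import reconstruct_stage1

coarse_mesh = reconstruct_stage1(
    mv_normal_pils,          # List[PIL.Image.Image], exactly 4 views (front/right/back/left)
    steps=100,
    vertices=init_vertices,  # (V,3) float tensor, coarse initial mesh
    faces=init_faces,        # (F,3) long tensor
    start_edge_len=0.15,
    end_edge_len=0.005,
    return_mesh=True,        # True: mesh object via scripts.utils.to_py3d_mesh; False: raw (vertices, faces) tensors
)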
    	
mesh_reconstruction/refine.py ADDED
@@ -0,0 +1,80 @@
from tqdm import tqdm
from PIL import Image
import torch
from typing import List
from mesh_reconstruction.remesh import calc_vertex_normals
from mesh_reconstruction.opt import MeshOptimizer
from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from scripts.project_mesh import multiview_color_projection, get_cameras_list
from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target

def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True):
    vertices, faces = vertices.to("cuda"), faces.to("cuda")
    if process_inputs:
        vertices = vertices * 2 / 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    poission_steps = []

    assert len(pils) == 4
    mv,proj = make_star_cameras_orthographic(4, 1)
    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
    # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
    # 1. no rotate
    target_images = target_images[[0, 3, 2, 1]]

    # 2. init from coarse mesh
    opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)

    vertices = opt.vertices
    alpha_init = None

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)
        if alpha_init is None:
            alpha_init = images.detach()

        if i < update_warmup or i % update_normal_interval == 0:
            with torch.no_grad():
                py3d_mesh = to_py3d_mesh(vertices, faces, normals)
                cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.)
                _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
                target_normal = target_normal * 2 - 1
                target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
                debug_images = renderer.render(vertices,target_normal,faces)

        d_mask = images[..., -1] > 0.5
        loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_debug_l2 + loss_alpha_target_mask_l2

        # out of box
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        vertices,faces = opt.remesh(poisson=(i in poission_steps))

    vertices, faces = vertices.detach().cpu(), faces.detach().cpu()

    if process_outputs:
        vertices = vertices / 2 * 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
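`run_mesh_refine` is the second pass: every few steps it bakes the four input views back onto the current mesh via multiview_color_projection and then optimizes the rendered normals toward that target. A sketch under the same naming assumptions as above; `vertices`/`faces` would come from `reconstruct_stage1(..., return_mesh=False)`.

from mesh_reconstruction.refine import run_mesh_refine

refined_mesh = run_mesh_refine(
    vertices, faces,            # coarse stage-1 mesh as raw tensors
    mv_normal_pils,             # the same 4 view images used in stage 1
    steps=100,
    start_edge_len=0.02,
    end_edge_len=0.005,
    update_normal_interval=10,  # how often the projected normal target is refreshed
    update_warmup=10,           # refresh every step during warmup
    return_mesh=True,
)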
    	
mesh_reconstruction/remesh.py ADDED
@@ -0,0 +1,361 @@
| 1 | 
            +
            # modified from https://github.com/Profactor/continuous-remeshing
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn.functional as tfunc
         | 
| 4 | 
            +
            import torch_scatter
         | 
| 5 | 
            +
            from typing import Tuple
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def prepend_dummies(
         | 
| 8 | 
            +
                    vertices:torch.Tensor, #V,D
         | 
| 9 | 
            +
                    faces:torch.Tensor, #F,3 long
         | 
| 10 | 
            +
                )->Tuple[torch.Tensor,torch.Tensor]:
         | 
| 11 | 
            +
                """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
         | 
| 12 | 
            +
                V,D = vertices.shape
         | 
| 13 | 
            +
                vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
         | 
| 14 | 
            +
                faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
         | 
| 15 | 
            +
                return vertices,faces
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            def remove_dummies(
         | 
| 18 | 
            +
                    vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
         | 
| 19 | 
            +
                    faces:torch.Tensor, #F,3 long - first face all zeros
         | 
| 20 | 
            +
                )->Tuple[torch.Tensor,torch.Tensor]:
         | 
| 21 | 
            +
                """remove dummy elements added with prepend_dummies()"""
         | 
| 22 | 
            +
                return vertices[1:],faces[1:]-1
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def calc_edges(
         | 
| 26 | 
            +
                    faces: torch.Tensor,  # F,3 long - first face may be dummy with all zeros
         | 
| 27 | 
            +
                    with_edge_to_face: bool = False
         | 
| 28 | 
            +
                ) -> Tuple[torch.Tensor, ...]:
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                returns Tuple of
         | 
| 31 | 
            +
                - edges E,2 long, 0 for unused, lower vertex index first
         | 
| 32 | 
            +
                - face_to_edge F,3 long
         | 
| 33 | 
            +
                - (optional) edge_to_face shape=E,[left,right],[face,side]
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                o-<-----e1     e0,e1...edge, e0<e1
         | 
| 36 | 
            +
                |      /A      L,R....left and right face
         | 
| 37 | 
            +
                |  L /  |      both triangles ordered counter clockwise
         | 
| 38 | 
            +
                |  / R  |      normals pointing out of screen
         | 
| 39 | 
            +
                V/      |      
         | 
| 40 | 
            +
                e0---->-o     
         | 
| 41 | 
            +
                """
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                F = faces.shape[0]
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                # make full edges, lower vertex index first
         | 
| 46 | 
            +
                face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
         | 
| 47 | 
            +
                full_edges = face_edges.reshape(F*3,2)
         | 
| 48 | 
            +
                sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # make unique edges
         | 
| 51 | 
            +
                edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
         | 
| 52 | 
            +
                E = edges.shape[0]
         | 
| 53 | 
            +
                face_to_edge = full_to_unique.reshape(F,3) #F,3
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                if not with_edge_to_face:
         | 
| 56 | 
            +
                    return edges, face_to_edge
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
         | 
| 59 | 
            +
                edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
         | 
| 60 | 
            +
                scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
         | 
| 61 | 
            +
                edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
         | 
| 62 | 
            +
                edge_to_face[0] = 0
         | 
| 63 | 
            +
                return edges, face_to_edge, edge_to_face
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            def calc_edge_length(
         | 
| 66 | 
            +
                    vertices:torch.Tensor, #V,3 first may be dummy
         | 
| 67 | 
            +
                    edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
         | 
| 68 | 
            +
                    )->torch.Tensor: #E
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                full_vertices = vertices[edges] #E,2,3
         | 
| 71 | 
            +
                a,b = full_vertices.unbind(dim=1) #E,3
         | 
| 72 | 
            +
                return torch.norm(a-b,p=2,dim=-1)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            def calc_face_normals(
         | 
| 75 | 
            +
                    vertices:torch.Tensor, #V,3 first vertex may be unreferenced
         | 
| 76 | 
            +
                    faces:torch.Tensor, #F,3 long, first face may be all zero
         | 
| 77 | 
            +
                    normalize:bool=False,
         | 
| 78 | 
            +
                    )->torch.Tensor: #F,3
         | 
| 79 | 
            +
                """
         | 
| 80 | 
            +
                     n
         | 
| 81 | 
            +
                     |
         | 
| 82 | 
            +
                     c0     corners ordered counterclockwise when
         | 
| 83 | 
            +
                    / \     looking onto surface (in neg normal direction)
         | 
| 84 | 
            +
                  c1---c2
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                full_vertices = vertices[faces] #F,C=3,3
         | 
| 87 | 
            +
                v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
         | 
| 88 | 
            +
                face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
         | 
| 89 | 
            +
                if normalize:
         | 
| 90 | 
            +
                    face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) 
         | 
| 91 | 
            +
                return face_normals #F,3
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            def calc_vertex_normals(
         | 
| 94 | 
            +
                    vertices:torch.Tensor, #V,3 first vertex may be unreferenced
         | 
| 95 | 
            +
                    faces:torch.Tensor, #F,3 long, first face may be all zero
         | 
| 96 | 
            +
                    face_normals:torch.Tensor=None, #F,3, not normalized
         | 
| 97 | 
            +
                    )->torch.Tensor: #F,3
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                F = faces.shape[0]
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                if face_normals is None:
         | 
| 102 | 
            +
                    face_normals = calc_face_normals(vertices,faces)
         | 
| 103 | 
            +
                
         | 
| 104 | 
            +
                vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
         | 
| 105 | 
            +
                vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
         | 
| 106 | 
            +
                vertex_normals = vertex_normals.sum(dim=1) #V,3
         | 
| 107 | 
            +
                return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            def calc_face_ref_normals(
         | 
| 110 | 
            +
                    faces:torch.Tensor, #F,3 long, 0 for unused
         | 
| 111 | 
            +
                    vertex_normals:torch.Tensor, #V,3 first unused
         | 
| 112 | 
            +
                    normalize:bool=False,
         | 
| 113 | 
            +
                    )->torch.Tensor: #F,3
         | 
| 114 | 
            +
                """calculate reference normals for face flip detection"""
         | 
| 115 | 
            +
                full_normals = vertex_normals[faces] #F,C=3,3
         | 
| 116 | 
            +
                ref_normals = full_normals.sum(dim=1) #F,3
         | 
| 117 | 
            +
                if normalize:
         | 
| 118 | 
            +
                    ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
         | 
| 119 | 
            +
                return ref_normals
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            def pack(
         | 
| 122 | 
            +
                    vertices:torch.Tensor, #V,3 first unused and nan
         | 
| 123 | 
            +
                    faces:torch.Tensor, #F,3 long, 0 for unused
         | 
| 124 | 
            +
                    )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
         | 
| 125 | 
            +
                """removes unused elements in vertices and faces"""
         | 
| 126 | 
            +
                V = vertices.shape[0]
         | 
| 127 | 
            +
                
         | 
| 128 | 
            +
                # remove unused faces
         | 
| 129 | 
            +
                used_faces = faces[:,0]!=0
         | 
| 130 | 
            +
                used_faces[0] = True
         | 
| 131 | 
            +
                faces = faces[used_faces] #sync
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # remove unused vertices
         | 
| 134 | 
            +
                used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
         | 
| 135 | 
            +
                used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') 
         | 
| 136 | 
            +
                used_vertices = used_vertices.any(dim=1)
         | 
| 137 | 
            +
                used_vertices[0] = True
         | 
| 138 | 
            +
                vertices = vertices[used_vertices] #sync
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # update used faces
         | 
| 141 | 
            +
                ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
         | 
| 142 | 
            +
                V1 = used_vertices.sum()
         | 
| 143 | 
            +
                ind[used_vertices] =  torch.arange(0,V1,device=vertices.device) #sync
         | 
| 144 | 
            +
                faces = ind[faces]
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                return vertices,faces
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            def split_edges(
         | 
| 149 | 
            +
                    vertices:torch.Tensor, #V,3 first unused
         | 
| 150 | 
            +
                    faces:torch.Tensor, #F,3 long, 0 for unused
         | 
| 151 | 
            +
                    edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
         | 
| 152 | 
            +
                    face_to_edge:torch.Tensor, #F,3 long 0 for unused
         | 
| 153 | 
            +
                    splits, #E bool
         | 
| 154 | 
            +
                    pack_faces:bool=True,
         | 
| 155 | 
            +
                    )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                #   c2                    c2               c...corners = faces
         | 
| 158 | 
            +
                #    . .                   . .             s...side_vert, 0 means no split
         | 
| 159 | 
            +
                #    .   .                 .N2 .           S...shrunk_face
         | 
| 160 | 
            +
                #    .     .               .     .         Ni...new_faces
         | 
| 161 | 
            +
                #   s2      s1           s2|c2...s1|c1
         | 
| 162 | 
            +
                #    .        .            .     .  .
         | 
| 163 | 
            +
                #    .          .          . S .      .
         | 
| 164 | 
            +
                #    .            .        . .     N1    .
         | 
| 165 | 
            +
                #   c0...(s0=0)....c1    s0|c0...........c1
         | 
| 166 | 
            +
                #
         | 
| 167 | 
            +
                # pseudo-code:
         | 
| 168 | 
            +
                #   S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
         | 
| 169 | 
            +
                #   split = side_vert!=0 example:[False,True,True]
         | 
| 170 | 
            +
                #   N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
         | 
| 171 | 
            +
                #   N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
         | 
| 172 | 
            +
                #   N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                V = vertices.shape[0]
         | 
| 175 | 
            +
                F = faces.shape[0]
         | 
| 176 | 
            +
                S = splits.sum().item() #sync
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                if S==0:
         | 
| 179 | 
            +
                    return vertices,faces
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
                edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
         | 
| 182 | 
            +
                edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
         | 
| 183 | 
            +
                side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
         | 
| 184 | 
            +
                split_edges = edges[splits] #S sync
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                #vertices
         | 
| 187 | 
            +
                split_vertices = vertices[split_edges].mean(dim=1) #S,3
         | 
| 188 | 
            +
                vertices = torch.concat((vertices,split_vertices),dim=0)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                #faces
         | 
| 191 | 
            +
                side_split = side_vert!=0 #F,3
         | 
| 192 | 
            +
                shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
         | 
| 193 | 
            +
                new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
         | 
| 194 | 
            +
                faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
         | 
| 195 | 
            +
                if pack_faces:
         | 
| 196 | 
            +
                    mask = faces[:,0]!=0
         | 
| 197 | 
            +
                    mask[0] = True
         | 
| 198 | 
            +
                    faces = faces[mask] #F',3 sync
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                return vertices,faces
         | 
| 201 | 
            +
             | 
| 202 | 
            +
            def collapse_edges(
         | 
| 203 | 
            +
                    vertices:torch.Tensor, #V,3 first unused
         | 
| 204 | 
            +
                    faces:torch.Tensor, #F,3 long 0 for unused
         | 
| 205 | 
            +
                    edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
         | 
| 206 | 
            +
                    priorities:torch.Tensor, #E float
         | 
| 207 | 
            +
                    stable:bool=False, #only for unit testing
         | 
| 208 | 
            +
                    )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
         | 
| 209 | 
            +
                    
         | 
| 210 | 
            +
                V = vertices.shape[0]
         | 
| 211 | 
            +
                
         | 
| 212 | 
            +
                # check spacing
         | 
| 213 | 
            +
                _,order = priorities.sort(stable=stable) #E
         | 
| 214 | 
            +
                rank = torch.zeros_like(order)
         | 
| 215 | 
            +
                rank[order] = torch.arange(0,len(rank),device=rank.device)
         | 
| 216 | 
            +
                vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
         | 
| 217 | 
            +
                edge_rank = rank #E
         | 
| 218 | 
            +
                for i in range(3):
         | 
| 219 | 
            +
                    torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
         | 
| 220 | 
            +
                    edge_rank,_ = vert_rank[edges].max(dim=-1) #E
         | 
| 221 | 
            +
                candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                # check connectivity
         | 
| 224 | 
            +
                vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
         | 
| 225 | 
            +
                vert_connections[candidates[:,0]] = 1 #start
         | 
| 226 | 
            +
                edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
         | 
| 227 | 
            +
                vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
         | 
| 228 | 
            +
                vert_connections[candidates] = 0 #clear start and end
         | 
| 229 | 
            +
                edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
         | 
| 230 | 
            +
                vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
         | 
| 231 | 
            +
                collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                # mean vertices
         | 
| 234 | 
            +
                vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) 
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                # update faces
         | 
| 237 | 
            +
                dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
         | 
| 238 | 
            +
                dest[collapses[:,1]] = dest[collapses[:,0]]
         | 
| 239 | 
            +
                faces = dest[faces] #F,3 
         | 
| 240 | 
            +
                c0,c1,c2 = faces.unbind(dim=-1)
         | 
| 241 | 
            +
                collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
         | 
| 242 | 
            +
                faces[collapsed] = 0
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                return vertices,faces
         | 
| 245 | 
            +
             | 
| 246 | 
            +
            def calc_face_collapses(
         | 
| 247 | 
            +
                    vertices:torch.Tensor, #V,3 first unused
         | 
| 248 | 
            +
                    faces:torch.Tensor, #F,3 long, 0 for unused
         | 
| 249 | 
            +
                    edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
         | 
| 250 | 
            +
                    face_to_edge:torch.Tensor, #F,3 long 0 for unused
         | 
| 251 | 
            +
                    edge_length:torch.Tensor, #E
         | 
| 252 | 
            +
                    face_normals:torch.Tensor, #F,3
         | 
| 253 | 
            +
                    vertex_normals:torch.Tensor, #V,3 first unused
         | 
| 254 | 
            +
                    min_edge_length:torch.Tensor=None, #V
         | 
| 255 | 
            +
                    area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
         | 
| 256 | 
            +
                    shortest_probability = 0.8
         | 
| 257 | 
            +
                    )->torch.Tensor: #E edges to collapse
         | 
| 258 | 
            +
                
         | 
| 259 | 
            +
                E = edges.shape[0]
         | 
| 260 | 
            +
                F = faces.shape[0]
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                # face flips
         | 
| 263 | 
            +
                ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
         | 
| 264 | 
            +
                face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
         | 
| 265 | 
            +
                
         | 
| 266 | 
            +
                # small faces
         | 
| 267 | 
            +
                if min_edge_length is not None:
         | 
| 268 | 
            +
                    min_face_length = min_edge_length[faces].mean(dim=-1) #F
         | 
| 269 | 
            +
                    min_area = min_face_length**2 * area_ratio #F
         | 
| 270 | 
            +
                    face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
         | 
| 271 | 
            +
                    face_collapses[0] = False
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                # faces to edges
         | 
| 274 | 
            +
                face_length = edge_length[face_to_edge] #F,3
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                if shortest_probability<1:
         | 
| 277 | 
            +
                    #select shortest edge with shortest_probability chance
         | 
| 278 | 
            +
                    randlim = round(2/(1-shortest_probability))
         | 
| 279 | 
            +
                    rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
         | 
| 280 | 
            +
                    sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
         | 
| 281 | 
            +
                    local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
         | 
| 282 | 
            +
                else:
         | 
| 283 | 
            +
                    local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
         | 
| 284 | 
            +
                
         | 
| 285 | 
            +
                edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
         | 
| 286 | 
            +
                edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
         | 
| 287 | 
            +
                edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) 
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                return edge_collapses.bool()
         | 
| 290 | 
            +
             | 
def flip_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
        edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
        edge_to_face:torch.Tensor, #E,[left,right],[face,side]
        with_border:bool=True, #handle border edges (D=4 instead of D=6)
        with_normal_check:bool=True, #check face normal flips
        stable:bool=False, #only for unit testing
        ):
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
    edge_is_inside = neighbors.all(dim=-1) #E

    if with_border:
        # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
        # need to use float for masks in order to use scatter(reduce='multiply')
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
        vertex_degree -= 2 * vertex_is_inside #V long

    neighbor_degrees = vertex_degree[neighbors] #E,LR=2
    edge_degrees = vertex_degree[edges] #E,2
    #
    # loss = Sum_over_affected_vertices((new_degree-6)**2)
    # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
    #                   + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
    #             = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
    #
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
    candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
    loss_change = loss_change[candidates] #E'
    if loss_change.shape[0]==0:
        return

    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
    _,order = loss_change.sort(descending=True, stable=stable) #E'
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1) #V
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
    flip = rank==neighborhood_rank #E'

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   |  L /  |      both triangles ordered counter clockwise
        #   |  / R  |      normals pointing out of screen
        #   V/      |
        #   e0---->-cr
        v = vertices[edges_neighbors] #E",4,3
        v = v - v[:,0:1] #make relative to e0
        e1 = v[:,1]
        cl = v[:,2]
        cr = v[:,3]
        n = torch.cross(e1,cl,dim=-1) + torch.cross(cr,e1,dim=-1) #sum of old normal vectors
        flip.logical_and_(torch.sum(n*torch.cross(cr,cl,dim=-1),dim=-1)>0) #first new face
        flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1,dim=-1),dim=-1)>0) #second new face

    flip_edges_neighbors = edges_neighbors[flip] #E",4
    flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
    flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
    faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
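The flip criterion above is the valence-regularity heuristic spelled out in the comments inside flip_edges: interior vertices ideally have degree 6, and flipping an edge raises the degree of the two opposite vertices by one while lowering the degree of the two edge endpoints by one. A quick sanity check of the simplified loss_change expression (a toy calculation with made-up degrees, not part of the repository code):

# toy check of: loss_change = 2 * (2 + sum(neighbor degrees) - sum(edge-endpoint degrees))
def degree_loss(degrees):                       # sum of (degree - 6)**2 over affected vertices
    return sum((d - 6) ** 2 for d in degrees)

edge_deg, nbr_deg = [7, 7], [5, 5]              # endpoints too busy, opposite vertices too sparse
before = degree_loss(edge_deg + nbr_deg)        # 4
after = degree_loss([d - 1 for d in edge_deg] + [d + 1 for d in nbr_deg])   # 0
assert after - before == 2 * (2 + sum(nbr_deg) - sum(edge_deg)) == -4       # flip is beneficial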
    	
mesh_reconstruction/render.py
ADDED
@@ -0,0 +1,159 @@
# modified from https://github.com/Profactor/continuous-remeshing
import nvdiffrast.torch as dr
import torch
from typing import Tuple

def _warmup(glctx, device=None):
    device = 'cuda' if device is None else device
    #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
    def tensor(*args, **kwargs):
        return torch.tensor(*args, device=device, **kwargs)
    pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
    tri = tensor([[0, 1, 2]], dtype=torch.int32)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])

class NormalsRenderer:

    _glctx:dr.RasterizeCudaContext = None

    def __init__(
            self,
            mv: torch.Tensor, #C,4,4
            proj: torch.Tensor, #C,4,4
            image_size: Tuple[int,int],
            mvp = None,
            device=None,
            ):
        if mvp is None:
            self._mvp = proj @ mv #C,4,4
        else:
            self._mvp = mvp
        self._image_size = image_size
        self._glctx = dr.RasterizeCudaContext(device=device)
        _warmup(self._glctx, device)

    def render(self,
            vertices: torch.Tensor, #V,3 float
            normals: torch.Tensor, #V,3 float   in [-1, 1]
            faces: torch.Tensor, #F,3 long
            ) ->torch.Tensor: #C,H,W,4

        V = vertices.shape[0]
        faces = faces.type(torch.int32)
        vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
        vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
        rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
        vert_col = (normals+1)/2 #V,3
        col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
        alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
        col = torch.concat((col,alpha),dim=-1) #C,H,W,4
        col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
        return col #C,H,W,4

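# Minimal usage sketch for NormalsRenderer (comments only; it mirrors the
# self-test at the bottom of this file and assumes the star-camera helper from
# mesh_reconstruction.func plus CUDA and nvdiffrast being available; the image
# size and output prefix are hypothetical choices):
#
#   mv, proj = make_star_cameras_orthographic(4, 1)
#   renderer = NormalsRenderer(mv, proj, [512, 512], device="cuda")
#   images = renderer.render(vertices, normals, faces)   # C,H,W,4, colors in [0, 1]
#   save_tensor_to_img(images, "out_")                   # helper defined below
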
from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRendererWithFragments,
    TexturesVertex,
    MeshRasterizer,
    BlendParams,
    FoVOrthographicCameras,
    look_at_view_transform,
    hard_rgb_blend,
)

class VertexColorShader(ShaderBase):
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        return hard_rgb_blend(texels, fragments, blend_params)

def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
    if len(mesh) != len(cameras):
        if len(cameras) % len(mesh) == 0:
            mesh = mesh.extend(len(cameras))
        else:
            raise NotImplementedError()

    # rendering requires everything in float16 or float32
    input_dtype = dtype
    blend_params = BlendParams(1e-4, 1e-4, bkgd)

    # Define the settings for rasterization and shading
    raster_settings = RasterizationSettings(
        image_size=(H, W),
        blur_radius=blur_radius,
        faces_per_pixel=faces_per_pixel,
        clip_barycentric_coords=True,
        bin_size=None,
        max_faces_per_bin=500000,
    )

    # Create a renderer by composing a rasterizer and a shader
    # We simply render vertex colors through the custom VertexColorShader (no lighting or materials are used)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=VertexColorShader(
            device=device,
            cameras=cameras,
            blend_params=blend_params
        )
    )

    # render RGB and depth, get mask
    with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
        images, _ = renderer(mesh)
    return images   # BHW4

class Pytorch3DNormalsRenderer:
    def __init__(self, cameras, image_size, device):
        self.cameras = cameras.to(device)
        self._image_size = image_size
        self.device = device

    def render(self,
            vertices: torch.Tensor, #V,3 float
            normals: torch.Tensor, #V,3 float   in [-1, 1]
            faces: torch.Tensor, #F,3 long
            ) ->torch.Tensor: #C,H,W,4
        mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
        return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)

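# Minimal usage sketch for the PyTorch3D path (comments only; the camera setup
# below is illustrative and built directly from the imports above rather than
# from the make_star_cameras_orthographic_py3d helper used in the self-test):
#
#   R, T = look_at_view_transform(dist=4.0, elev=0, azim=[0, 90, 180, 270])
#   cameras = FoVOrthographicCameras(R=R, T=T, device="cuda")
#   renderer = Pytorch3DNormalsRenderer(cameras, image_size=(512, 512), device="cuda")
#   images = renderer.render(vertices, normals, faces)    # C,H,W,4, normals remapped to [0, 1]
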
def save_tensor_to_img(tensor, save_dir):
    from PIL import Image
    import numpy as np
    for idx, img in enumerate(tensor):
        img = img[..., :3].cpu().numpy()
        img = (img * 255).astype(np.uint8)
        img = Image.fromarray(img)
        img.save(save_dir + f"{idx}.png")

if __name__ == "__main__":
    import sys
    import os
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
    cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    mv,proj = make_star_cameras_orthographic(4, 1)
    resolution = 1024
    renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
    renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
    vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
    normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
    faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)

    import time
    t0 = time.time()
    r1 = renderer1.render(vertices, normals, faces)
    print("time r1:", time.time() - t0)

    t0 = time.time()
    r2 = renderer2.render(vertices, normals, faces)
    print("time r2:", time.time() - t0)

    for i in range(4):
        print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
    	
package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff4a35615ed42148c8579622bee6dca88f7f3be683671524a282fafaf7589682
size 3079614
    	
package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11e7f7f781fef16c09ec8d03bfb6da84cf61c54fc59e8a4ea047a90c4a24e88f
size 162720703

