Spaces · Running on Zero
Commit: init

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +3 -0
 - .gitignore +4 -0
 - app.py +390 -0
 - assets/teaser.jpg +3 -0
 - configs/infer.yml +52 -0
 - data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply +0 -0
 - data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply +0 -0
 - data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply +0 -0
 - data/basic_shapes_norm/basic_shapes.json +89 -0
 - data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply +3 -0
 - data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply +3 -0
 - data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply +3 -0
 - data/demo_glb/barbell.glb +3 -0
 - data/demo_glb/book.glb +3 -0
 - data/demo_glb/bunny.glb +3 -0
 - data/demo_glb/desk.glb +3 -0
 - data/demo_glb/man.glb +3 -0
 - data/demo_glb/micky.glb +3 -0
 - data/demo_glb/pac.glb +3 -0
 - data/demo_glb/robot.glb +3 -0
 - data/demo_glb/rocket.glb +3 -0
 - data/demo_glb/sheep.glb +3 -0
 - data/demo_glb/shelf.glb +3 -0
 - data/demo_glb/table.glb +3 -0
 - data/demo_glb/vent.glb +3 -0
 - data/demo_glb/walkman.glb +3 -0
 - pre-requirements.txt +36 -0
 - primitive_anything/__init__.py +0 -0
 - primitive_anything/michelangelo/__init__.py +51 -0
 - primitive_anything/michelangelo/data/__init__.py +1 -0
 - primitive_anything/michelangelo/data/templates.json +69 -0
 - primitive_anything/michelangelo/data/transforms.py +407 -0
 - primitive_anything/michelangelo/data/utils.py +59 -0
 - primitive_anything/michelangelo/graphics/__init__.py +1 -0
 - primitive_anything/michelangelo/graphics/primitives/__init__.py +9 -0
 - primitive_anything/michelangelo/graphics/primitives/mesh.py +114 -0
 - primitive_anything/michelangelo/graphics/primitives/volume.py +21 -0
 - primitive_anything/michelangelo/models/__init__.py +1 -0
 - primitive_anything/michelangelo/models/asl_diffusion/__init__.py +1 -0
 - primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py +483 -0
 - primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py +104 -0
 - primitive_anything/michelangelo/models/asl_diffusion/base.py +13 -0
 - primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py +393 -0
 - primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py +80 -0
 - primitive_anything/michelangelo/models/conditional_encoders/__init__.py +3 -0
 - primitive_anything/michelangelo/models/conditional_encoders/clip.py +89 -0
 - primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py +562 -0
 - primitive_anything/michelangelo/models/modules/__init__.py +3 -0
 - primitive_anything/michelangelo/models/modules/checkpoint.py +69 -0
 - primitive_anything/michelangelo/models/modules/diffusion_transformer.py +218 -0
 
    	
.gitattributes CHANGED

@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/demo_glb/*.glb filter=lfs diff=lfs merge=lfs -text
+assets/*.jpg filter=lfs diff=lfs merge=lfs -text
+data/basic_shapes_norm_pc10000/*.ply filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
@@ -0,0 +1,4 @@

**/__pycache__/
ckpt
gradio_cached_examples
results
    	
app.py ADDED
@@ -0,0 +1,390 @@

import os
import time
import glob
import json
import yaml
import torch
import trimesh
import argparse
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import list_repo_files

from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log

os.environ['PYOPENGL_PLATFORM'] = 'egl'

import spaces

repo_id = "hyz317/PrimitiveAnything"
all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    if os.path.exists(file):
        continue
    hf_hub_download(repo_id, file, local_dir="./ckpt")
hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")

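# The os.path.exists(file) guard checks the repo-relative path in the current
# working directory, while hf_hub_download places the files under ./ckpt, so the
# guard only skips files that already happen to exist at those relative paths.
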
def parse_args():
    parser = argparse.ArgumentParser(description='Process 3D model files')

    parser.add_argument(
        '--input',
        type=str,
        default='./data/demo_glb/',
        help='Input file or directory path (default: ./data/demo_glb/)'
    )

    parser.add_argument(
        '--log_path',
        type=str,
        default='./results/demo',
        help='Output directory path (default: ./results/demo)'
    )

    return parser.parse_args()

def get_input_files(input_path):
    if os.path.isfile(input_path):
        return [input_path]
    elif os.path.isdir(input_path):
        return glob.glob(os.path.join(input_path, '*'))
    else:
        raise ValueError(f"Input path {input_path} is neither a file nor a directory")

args = parse_args()

# Create output directory
LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)

print(f"Output directory: {LOG_PATH}")

CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}

shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}

#### config
bs_dir = 'data/basic_shapes_norm'
config_path = './configs/infer.yml'
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
temperature = 0.0

#### init model
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
    bs_name = os.path.basename(bs_path)
    bs = trimesh.load(bs_path)
    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual = bs.visual.to_color()
    mesh_bs[bs_name] = bs

def create_model(cfg_model):
    kwargs = cfg_model
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model

from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
def get_model(name):
    return {
        'discrete': PrimitiveTransformerDiscrete,
    }[name]

with open(config_path, mode='r') as fp:
    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)

AR_checkpoint = torch.load(AR_checkpoint_path)

transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)

device = torch.device('cuda')
accelerator = Accelerator(
    mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')


def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
                          return_surface_pc_normals=False, normalized=False):
    sample_start = time.time()
    if surface_point_method == 'sample' and sign_method == 'depth':
        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
        sign_method = 'normal'

    surface_start = time.time()
    bound_radius = 1 if normalized else None
    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
                                                  sample_point_count,
                                                  calculate_normals=sign_method == 'normal' or return_gradients)

    surface_end = time.time()
    print('surface point cloud time cost :', surface_end - surface_start)

    normal_start = time.time()
    if return_surface_pc_normals:
        rng = np.random.default_rng()
        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
        indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
        points = surface_point_cloud.points[indices]
        normals = surface_point_cloud.normals[indices]
        surface_points = np.concatenate([points, normals], axis=-1)
    else:
        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
    normal_end = time.time()
    print('normal time cost :', normal_end - normal_start)
    sample_end = time.time()
    print('sample surface point time cost :', sample_end - sample_start)
    return surface_points


def normalize_vertices(vertices, scale=0.9):
    bbmin, bbmax = vertices.min(0), vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    return vertices, center, scale


def export_to_watertight(normalized_mesh, octree_depth: int = 7):
    """
    Convert a non-watertight mesh to a watertight one.

    Args:
        normalized_mesh (trimesh.Trimesh): input mesh
        octree_depth (int): resolution of the SDF grid (2 ** octree_depth cells per side)

    Returns:
        mesh (trimesh.Trimesh): watertight mesh
    """
    size = 2 ** octree_depth
    level = 2 / size

    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)

    # map grid coordinates back: first to [-1, 1], then undo the normalization
    vertices = vertices / size * 2 - 1
    vertices = vertices / to_orig_scale + to_orig_center
    mesh = trimesh.Trimesh(vertices, faces, normals=normals)

    return mesh
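
# Running marching cubes on np.abs(sdf) at a one-voxel level (2 / size) extracts
# a closed surface sitting just outside the original geometry, which is why the
# result is watertight even when the input mesh has holes or open boundaries.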

def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
    # mesh_list: list of trimesh.Trimesh
    pc_normal_list = []
    return_mesh_list = []
    for mesh in mesh_list:
        if marching_cubes:
            mesh = export_to_watertight(mesh)
            print("MC over!")
        if dilated_offset > 0:
            # push every vertex outward along its normal to fatten thin parts
            new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
            mesh.vertices = new_vertices
            print("dilate over!")

        mesh.merge_vertices()
        mesh.update_faces(mesh.unique_faces())
        mesh.fix_normals()

        return_mesh_list.append(mesh)

        pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
        pc_normal_list.append(pc_normal)
        print("process mesh success")
    return pc_normal_list, return_mesh_list


#### utils
def euler_to_quat(euler):
    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()

def SRT_quat_to_matrix(scale, quat, translation):
    rotation_matrix = Rotation.from_quat(quat).as_matrix()
    transform_matrix = np.eye(4)
    transform_matrix[:3, :3] = rotation_matrix * scale
    transform_matrix[:3, 3] = translation
    return transform_matrix
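
# In SRT_quat_to_matrix, `rotation_matrix * scale` broadcasts the per-axis scale
# over the columns of R (equivalent to R @ np.diag(scale)), so each primitive is
# scaled along its own local axes before being rotated and translated.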

def write_output(primitives, name):
    out_json = {}
    out_json['operation'] = 0
    out_json['type'] = 1
    out_json['scene_id'] = None

    new_group = []
    model_scene = trimesh.Scene()
    color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
    color_map = (np.array(color_map) * 255).astype("uint8")
    for idx, (scale, rotation, translation, type_code) in enumerate(zip(
        primitives['scale'].squeeze().cpu().numpy(),
        primitives['rotation'].squeeze().cpu().numpy(),
        primitives['translation'].squeeze().cpu().numpy(),
        primitives['type_code'].squeeze().cpu().numpy()
    )):
        if type_code == -1:
            break
        bs_name = CODE_SHAPE[type_code]
        new_block = {}
        new_block['type_id'] = shapename_map[bs_name]
        new_block['data'] = {}
        new_block['data']['location'] = translation.tolist()
        new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
        new_block['data']['scale'] = scale.tolist()
        new_block['data']['color'] = ['808080']
        new_group.append(new_block)

        trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
        bs = mesh_bs[bs_name].copy().apply_transform(trans)
        new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
        bs.visual.vertex_colors[:, :3] = new_vertex_colors
        vertices = bs.vertices.copy()
        vertices[:, 1] = bs.vertices[:, 2]
        vertices[:, 2] = -bs.vertices[:, 1]
        bs.vertices = vertices
        model_scene.add_geometry(bs)
    out_json['group'] = new_group

    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
    with open(json_path, 'w') as json_file:
        json.dump(out_json, json_file, indent=4)

    glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
    model_scene.export(glb_path)

    return glb_path, out_json
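
# The vertex shuffle above (y <- z, z <- -y) is a -90 degree rotation about the
# X axis, presumably converting the model's Z-up frame to the Y-up convention
# that glTF/GLB viewers expect.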

@torch.no_grad()
def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
    t1 = time.time()
    set_seed(sample_seed)
    input_mesh = trimesh.load(input_3d, force='mesh')

    # center the mesh and scale its longest bounding-box side to 1.6
    vertices = input_mesh.vertices
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
    vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
    input_mesh.vertices = vertices

    pc_list, mesh_list = process_mesh_to_surface_pc(
        [input_mesh],
        marching_cubes=do_marching_cubes,
        dilated_offset=dilated_offset
    )
    pc_normal = pc_list[0]  # (10000, 6)
    mesh = mesh_list[0]

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    if dilated_offset > 0:
        # re-normalize mesh and point cloud after dilation
        vertices = mesh.vertices
        bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
        vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
        vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
        mesh.vertices = vertices
        pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
        pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6

    input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
    mesh.export(input_save_name)

    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

    with accelerator.autocast():
        if postprocess == 'postprocess1':
            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
        else:
            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)

    output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4])

    return input_save_name, output_glb, output_json
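
# With temperature fixed at 0.0 (config section above) decoding is effectively
# greedy; 'postprocess1' switches to generate_w_recon_loss, which, going by its
# name and the single_directional flag, additionally applies a reconstruction
# loss during generation.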

import gradio as gr

@spaces.GPU
def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
    print(f"processing: {input_3d}")
    # try:
    preprocess_model_obj, output_model_obj, output_model_json = do_inference(
        input_3d,
        dilated_offset=dilated_offset,
        do_marching_cubes=do_marching_cubes,
        postprocess=postprocess_method
    )
    return output_model_obj
    # except Exception as e:
    #     return f"Error processing file: {str(e)}"

# Title and reminder placeholders
title = "3D Model Processing Demo"
reminder = "Please upload your 3D model file and adjust parameters as needed."

with gr.Blocks(title=title) as demo:
    # Title section
    gr.Markdown(f"# {title}")
    gr.Markdown(reminder)

    with gr.Row():
        with gr.Column():
            # Input components
            input_3d = gr.Model3D(label="Upload 3D Model File")
            dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
            do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
            submit_btn = gr.Button("Process Model")

        with gr.Column():
            # Output components
            output = gr.Model3D(label="Primitive Assembly Prediction")

    submit_btn.click(
        fn=process_3d_model,
        inputs=[input_3d, dilated_offset, do_marching_cubes],
        outputs=output
    )

    # Prepare examples (each entry wrapped in a list; GLB files only)
    example_files = [[f] for f in glob.glob('./data/demo_glb/*.glb')]

    example = gr.Examples(
        examples=example_files,
        inputs=[input_3d],  # only the Model3D input
        examples_per_page=14,
    )

if __name__ == "__main__":
    demo.launch()
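For a quick smoke test of the pipeline outside the Gradio UI, a minimal sketch (assuming the checkpoints above have finished downloading, a CUDA device is available, and barbell.glb is one of the bundled demo files):

    # Run one demo mesh through do_inference() and inspect the outputs.
    processed_glb, output_glb, output_json = do_inference(
        './data/demo_glb/barbell.glb',
        dilated_offset=0.015,       # same default the UI exposes
        do_marching_cubes=True,     # remesh to watertight first
        postprocess='postprocess1'  # use the recon-loss decoding path
    )
    print(processed_glb)              # ./results/demo/processed_barbell.glb
    print(output_glb)                 # ./results/demo/output_barbell.glb
    print(len(output_json['group']))  # number of predicted primitives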
    	
assets/teaser.jpg ADDED

Binary file stored with Git LFS.
    	
configs/infer.yml ADDED
@@ -0,0 +1,52 @@

dataset:
  name: base
  pc_dir: ./data/test_pc
  bs_dir: data/basic_shapes_norm
  max_length: 144
  range_scale: [0, 1]
  range_rotation: [-180, 180]
  range_translation: [-1, 1]
  rotation_type: euler
  pc_format: pn
model:
  attn_depth: 6
  attn_heads: 6
  bin_smooth_blur_sigma: -1
  bs_pc_dir: data/basic_shapes_norm_pc10000
  coarse_pre_gateloop_depth: 3
  continuous_range_rotation:
  - -181
  - 181
  continuous_range_scale:
  - 0
  - 1
  continuous_range_translation:
  - -1
  - 1
  dim: 768
  dim_rotation_embed: 16
  dim_scale_embed: 16
  dim_translation_embed: 16
  dim_type_embed: 48
  dropout: 0.0
  embed_order: ctrs
  gateloop_use_heinsen: false
  loss_weight:
    eos: 1.0
    reconstruction: 1.0
    rotation: 1.0
    scale: 1.0
    translation: 1.0
    type: 1.0
  max_primitive_len: 144
  name: discrete
  num_discrete_rotation: 181
  num_discrete_scale: 128
  num_discrete_translation: 128
  num_type: 3
  shape_cond_with_cat: true
  shape_cond_with_cross_attn: false
  shape_cond_with_film: false
  shape_condition_dim: 768
  shape_condition_len: 77
  shape_condition_model_type: michelangelo
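For orientation, app.py consumes this file by handing the whole model: block to create_model, which pops name to select the transformer class and forwards every remaining key as a constructor kwarg. A minimal sketch of that round trip (same yaml.FullLoader call as in app.py):

    import yaml

    with open('./configs/infer.yml') as fp:
        cfg = yaml.load(fp, Loader=yaml.FullLoader)

    model_cfg = dict(cfg['model'])   # copy, since create_model pops in place
    name = model_cfg.pop('name')     # 'discrete' -> PrimitiveTransformerDiscrete
    print(name, model_cfg['dim'], model_cfg['num_type'])  # discrete 768 3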
    	
data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply ADDED

Binary file (10.1 kB).

data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply ADDED

Binary file (4.96 kB).

data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply ADDED

Binary file (27.5 kB).
    	
data/basic_shapes_norm/basic_shapes.json
ADDED
@@ -0,0 +1,89 @@
+{
+    "SM_GR_BS_CubeBevel_001.ply": {
+        "name": "SM_GR_BS_CubeBevel_001.ply",
+        "tform_bs_to_normalized": [
+            [0.02, 0.0, 0.0, 0.0],
+            [0.0, 0.02, 0.0, 9.701276818911235e-18],
+            [0.0, 0.0, 0.019999999999999997, -0.9999999999999999],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    },
+    "SM_GR_BS_CylinderSharp_001.ply": {
+        "name": "SM_GR_BS_CylinderSharp_001.ply",
+        "tform_bs_to_normalized": [
+            [0.006666668023003748, 0.0, 0.0, -2.0345056221459462e-07],
+            [0.0, 0.006666667683919426, 0.0, -5.086263794939386e-08],
+            [0.0, 0.0, 0.006666665445429783, -0.9999998370794186],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    },
+    "SM_GR_BS_SphereSharp_001.ply": {
+        "name": "SM_GR_BS_SphereSharp_001.ply",
+        "tform_bs_to_normalized": [
+            [0.006666666666666667, 0.0, 0.0, 0.0],
+            [0.0, 0.006666666666666667, 0.0, 0.0],
+            [0.0, 0.0, 0.006666666666666667, -1.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    }
+}
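Each entry stores a 4x4 homogeneous matrix, `tform_bs_to_normalized`: a uniform scale on the diagonal of the top-left 3x3 block and a translation in the last column. A hedged sketch of applying it with trimesh (a dependency listed in pre-requirements.txt below), assuming the matrix carries the raw basic-shape mesh into the normalized frame; the paths mirror the files added in this commit:

    import json
    import numpy as np
    import trimesh

    with open("data/basic_shapes_norm/basic_shapes.json") as f:
        shapes = json.load(f)

    name = "SM_GR_BS_CubeBevel_001.ply"
    tform = np.array(shapes[name]["tform_bs_to_normalized"])  # 4x4 homogeneous transform

    mesh = trimesh.load(f"data/basic_shapes_norm/{name}")
    mesh.apply_transform(tform)  # for the cube entry: scale by 0.02, shift z by -1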
    	
data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba980c1fb389e30783f09b07d35e788e08a97776d933b8bfd346147c9a7e86a0
+size 510265

data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab8fb7aa7ec39237474d0a6e77da1d7070742f61af21e6b44dc9998fac1913cc
+size 510265

data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8765da7294292422d077267c1b71b9ea055f831aab3840d869656632ee6e8569
+size 510265
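These three entries are git-lfs pointer files, not the geometry itself; the payloads (identical 510,265-byte PLYs, presumably 10,000-point samplings of the basic shapes, as the `_pc10000` directory suffix suggests) are fetched with `git lfs pull`. A hedged loading sketch with open3d (also listed in pre-requirements.txt):

    import open3d as o3d

    # requires the real payload, i.e. run `git lfs pull` first
    pcd = o3d.io.read_point_cloud("data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply")
    print(len(pcd.points))  # expected 10000, per the directory suffix (assumption)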
    	
data/demo_glb/barbell.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a9b9c124c321d6d18342b12407fc7327bdd56c8720d317e7b8694c10c851936
+size 769528

data/demo_glb/book.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be526e9bca2ce3a74387f2dde7f6a25c9502a7c6d4f9fc671b244d09c18a9d94
+size 5369916

data/demo_glb/bunny.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f62b2169b7cda3662de660d0b2e8ce2a1ccfe3bd186f243d890440e8cf7a0766
+size 27518016

data/demo_glb/desk.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c8fd8041f1e870ba285572f3fb3e129107678ab9a311524a8376cc404cc332e
+size 33679548

data/demo_glb/man.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:063a56d0a56d3866cf36170bbafad93924fd350882ac0f69e727cb43dc203351
+size 31784

data/demo_glb/micky.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc6410c2f2a588c5b064f9255a1c1657a9dff061ae6f7342df693c80eef0c69d
+size 294576

data/demo_glb/pac.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc39816cc71440fbc31d99c24f24ed42c25daf7319ba07b8dd3e34c1ea083578
+size 274004

data/demo_glb/robot.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf201fbe21d73428f88e4a7e428849148ebd500cc4ce6ac3929638a53c5376ae
+size 28116940

data/demo_glb/rocket.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29601d218e0d51a4dced2f1ed2a898e80f0d223b1e04212d45b0dda4ad670d1c
+size 1426588

data/demo_glb/sheep.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:232b1303e56ec1682536c72bc9409585930985492dcbdfa101cdfb96d0b4fbf2
+size 28732

data/demo_glb/shelf.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d63916c1ef1d5b2fc4d56e20c76316bd977f93819323f73e7f5e1c59df21e284
+size 3091336

data/demo_glb/table.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:743c96d7aa1bef88576c5f28d5e06144f93e27a2e6ea5ef8bd85669d1213af9f
+size 20093692

data/demo_glb/vent.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98d8d0d2c8d164fc75361d7d1408e9102f9148482c0343e0cc87d21950e20ab1
+size 1785468

data/demo_glb/walkman.glb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:992355cf45609881223561d1081a05483c2bad488ed7148f87243259aa36be1b
+size 158156
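The demo GLBs (also LFS pointers) are the example meshes the app accepts as input. Since the model is conditioned on point clouds, a typical preprocessing step is to load a GLB and sample its surface; a minimal sketch with trimesh, where the sample count of 10,000 is an assumption matching the basic-shape point clouds above:

    import trimesh

    loaded = trimesh.load("data/demo_glb/bunny.glb")
    # GLB files usually load as a Scene; merge it into a single mesh
    mesh = loaded.dump(concatenate=True) if isinstance(loaded, trimesh.Scene) else loaded

    points, _ = trimesh.sample.sample_surface(mesh, 10000)  # (10000, 3) surface point cloud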
    	
pre-requirements.txt
ADDED
@@ -0,0 +1,36 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+--extra-index-url https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html
+torch==2.2.0
+torchvision==0.17.0
+dgl
+accelerate
+beartype
+einops
+gateloop_transformer
+matplotlib
+scikit-learn
+pandas
+pytorch_custom_utils
+gradio
+pydantic==2.10.6
+x_transformers
+torch_redstone
+torchdata==0.9.0
+toolz
+environs
+jaxtyping
+omegaconf
+ema_pytorch
+local_attention==1.9.15
+taylor_series_linear_attention
+transformers
+vector_quantize_pytorch
+open3d
+trimesh
+pytorch_lightning
+scikit-image
+opencv-python
+mesh2sdf
+seaborn
+mesh_to_sdf
+point_cloud_utils
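Note that these pins target CUDA 12.1 wheels (`cu121`), and the dgl wheel index is built against torch 2.1 even though torch itself is pinned to 2.2.0. A quick sanity check that the installed stack matches what the pins intend:

    import torch

    print(torch.__version__)           # expect 2.2.0, per the pin above
    print(torch.version.cuda)          # expect '12.1' for the cu121 wheel index
    print(torch.cuda.is_available())   # True on a working GPU setup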
    	
primitive_anything/__init__.py
ADDED
    Empty file
    	
primitive_anything/michelangelo/__init__.py
ADDED
@@ -0,0 +1,51 @@
+import os
+
+from omegaconf import OmegaConf
+import torch
+from torch import nn
+
+from .utils.misc import instantiate_from_config
+from ..utils import default, exists
+
+
+def load_model():
+    model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml"))
+    # print(model_config)
+    if hasattr(model_config, "model"):
+        model_config = model_config.model
+    ckpt_path = "./ckpt/shapevae-256.ckpt"
+
+    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
+    # model = model.cuda()
+    model = model.eval()
+
+    return model
+
+
+class ShapeConditioner(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim_latent = None
+    ):
+        super().__init__()
+        self.model = load_model()
+
+        self.dim_model_out = 768
+        dim_latent = default(dim_latent, self.dim_model_out)
+        self.dim_latent = dim_latent
+
+    def forward(
+        self,
+        shape = None,
+        shape_embed = None,
+    ):
+        # exactly one of `shape` / `shape_embed` must be provided
+        assert exists(shape) ^ exists(shape_embed)
+
+        shape_head = None  # only populated when encoding from a raw shape below
+        if not exists(shape_embed):
+            point_feature = self.model.encode_latents(shape)
+            shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
+            shape_head = point_feature[:, 0:1]
+            shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)
+            # shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp
+        return shape_head, shape_embed
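`ShapeConditioner` wraps the Michelangelo shape VAE as a frozen encoder: `encode_latents` turns a point cloud into a token sequence whose first token becomes `shape_head`, while the remaining tokens, concatenated with `to_shape_latents(...)`, become `shape_embed`. A hedged usage sketch; it assumes `./ckpt/shapevae-256.ckpt` is present so `load_model()` succeeds, and the input layout (xyz plus normals) is an assumption:

    import torch
    from primitive_anything.michelangelo import ShapeConditioner

    conditioner = ShapeConditioner()    # loads ./ckpt/shapevae-256.ckpt via load_model()

    points = torch.randn(1, 4096, 6)    # assumed layout: 4096 points, xyz + normal
    with torch.no_grad():
        shape_head, shape_embed = conditioner(shape=points)
    # passing shape_embed=... instead skips re-encoding and returns shape_head=None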
    	
primitive_anything/michelangelo/data/__init__.py
ADDED
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
    	
primitive_anything/michelangelo/data/templates.json
ADDED
@@ -0,0 +1,69 @@
+{
+    "shape": [
+        "a point cloud model of {}.",
+        "There is a {} in the scene.",
+        "There is the {} in the scene.",
+        "a photo of a {} in the scene.",
+        "a photo of the {} in the scene.",
+        "a photo of one {} in the scene.",
+        "itap of a {}.",
+        "itap of my {}.",
+        "itap of the {}.",
+        "a photo of a {}.",
+        "a photo of my {}.",
+        "a photo of the {}.",
+        "a photo of one {}.",
+        "a photo of many {}.",
+        "a good photo of a {}.",
+        "a good photo of the {}.",
+        "a bad photo of a {}.",
+        "a bad photo of the {}.",
+        "a photo of a nice {}.",
+        "a photo of the nice {}.",
+        "a photo of a cool {}.",
+        "a photo of the cool {}.",
+        "a photo of a weird {}.",
+        "a photo of the weird {}.",
+        "a photo of a small {}.",
+        "a photo of the small {}.",
+        "a photo of a large {}.",
+        "a photo of the large {}.",
+        "a photo of a clean {}.",
+        "a photo of the clean {}.",
+        "a photo of a dirty {}.",
+        "a photo of the dirty {}.",
+        "a bright photo of a {}.",
+        "a bright photo of the {}.",
+        "a dark photo of a {}.",
+        "a dark photo of the {}.",
+        "a photo of a hard to see {}.",
+        "a photo of the hard to see {}.",
+        "a low resolution photo of a {}.",
+        "a low resolution photo of the {}.",
+        "a cropped photo of a {}.",
+        "a cropped photo of the {}.",
+        "a close-up photo of a {}.",
+        "a close-up photo of the {}.",
+        "a jpeg corrupted photo of a {}.",
+        "a jpeg corrupted photo of the {}.",
+        "a blurry photo of a {}.",
+        "a blurry photo of the {}.",
+        "a pixelated photo of a {}.",
+        "a pixelated photo of the {}.",
+        "a black and white photo of the {}.",
+        "a black and white photo of a {}",
+        "a plastic {}.",
+        "the plastic {}.",
+        "a toy {}.",
+        "the toy {}.",
+        "a plushie {}.",
+        "the plushie {}.",
+        "a cartoon {}.",
+        "the cartoon {}.",
+        "an embroidered {}.",
+        "the embroidered {}.",
+        "a painting of the {}.",
+        "a painting of a {}."
+    ]
+
+}
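These are CLIP-style prompt templates: each `{}` placeholder takes a category name, expanding one label into many text prompts (whose embeddings are typically averaged). Loading and using them is straightforward:

    import json

    with open("primitive_anything/michelangelo/data/templates.json") as f:
        templates = json.load(f)["shape"]

    prompts = [t.format("chair") for t in templates]
    print(prompts[0])    # "a point cloud model of chair."
    print(len(prompts))  # 64 templates in the "shape" list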
    	
primitive_anything/michelangelo/data/transforms.py
ADDED
@@ -0,0 +1,407 @@
+# -*- coding: utf-8 -*-
+import os
+import time
+import numpy as np
+import warnings
+import random
+from omegaconf.listconfig import ListConfig
+from webdataset import pipelinefilter
+import torch
+import torchvision.transforms.functional as TVF
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.transforms import _interpolation_modes_from_int
+from typing import Sequence
+
+from ..utils import instantiate_from_config
+
+
+def _uid_buffer_pick(buf_dict, rng):
+    uid_keys = list(buf_dict.keys())
+    selected_uid = rng.choice(uid_keys)
+    buf = buf_dict[selected_uid]
+
+    k = rng.randint(0, len(buf) - 1)
+    sample = buf[k]
+    buf[k] = buf[-1]
+    buf.pop()
+
+    if len(buf) == 0:
+        del buf_dict[selected_uid]
+
+    return sample
+
+
+def _add_to_buf_dict(buf_dict, sample):
+    key = sample["__key__"]
+    uid, uid_sample_id = key.split("_")
+    if uid not in buf_dict:
+        buf_dict[uid] = []
+    buf_dict[uid].append(sample)
+
+    return buf_dict
+
+
+def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
+    """Shuffle the data in the stream.
+
+    This uses a buffer of size `bufsize`. Shuffling at
+    startup is less random; this is traded off against
+    yielding samples quickly.
+
+    data: iterator
+    bufsize: buffer size for shuffling
+    returns: iterator
+    rng: either random module or random.Random instance
+
+    """
+    if rng is None:
+        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
+    initial = min(initial, bufsize)
+    buf_dict = dict()
+    current_samples = 0
+    for sample in data:
+        _add_to_buf_dict(buf_dict, sample)
+        current_samples += 1
+
+        if current_samples < bufsize:
+            try:
+                _add_to_buf_dict(buf_dict, next(data))  # skipcq: PYL-R1708
+                current_samples += 1
+            except StopIteration:
+                pass
+
+        if current_samples >= initial:
+            current_samples -= 1
+            yield _uid_buffer_pick(buf_dict, rng)
+
+    while current_samples > 0:
+        current_samples -= 1
+        yield _uid_buffer_pick(buf_dict, rng)
+
+
+uid_shuffle = pipelinefilter(_uid_shuffle)
+
+
+class RandomSample(object):
+    def __init__(self,
+                 num_volume_samples: int = 1024,
+                 num_near_samples: int = 1024):
+
+        super().__init__()
+
+        self.num_volume_samples = num_volume_samples
+        self.num_near_samples = num_near_samples
+
+    def __call__(self, sample):
+        rng = np.random.default_rng()
+
+        # 1. sample surface input
+        total_surface = sample["surface"]
+        ind = rng.choice(total_surface.shape[0], replace=False)
+        surface = total_surface[ind]
+
+        # 2. sample volume/near geometric points
+        vol_points = sample["vol_points"]
+        vol_label = sample["vol_label"]
+        near_points = sample["near_points"]
+        near_label = sample["near_label"]
+
+        ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
+        vol_points = vol_points[ind]
+        vol_label = vol_label[ind]
+        vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
+
+        ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
+        near_points = near_points[ind]
+        near_label = near_label[ind]
+        near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
+
+        # concat sampled volume and near points
+        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
+
+        sample = {
+            "surface": surface,
+            "geo_points": geo_points
+        }
+
+        return sample
+
+
+class SplitRandomSample(object):
+    def __init__(self,
+                 use_surface_sample: bool = False,
+                 num_surface_samples: int = 4096,
+                 num_volume_samples: int = 1024,
+                 num_near_samples: int = 1024):
+
+        super().__init__()
+
+        self.use_surface_sample = use_surface_sample
+        self.num_surface_samples = num_surface_samples
+        self.num_volume_samples = num_volume_samples
+        self.num_near_samples = num_near_samples
+
+    def __call__(self, sample):
+
+        rng = np.random.default_rng()
+
+        # 1. sample surface input
+        surface = sample["surface"]
+
+        if self.use_surface_sample:
+            replace = surface.shape[0] < self.num_surface_samples
+            ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
+            surface = surface[ind]
+
+        # 2. sample volume/near geometric points
+        vol_points = sample["vol_points"]
+        vol_label = sample["vol_label"]
+        near_points = sample["near_points"]
+        near_label = sample["near_label"]
+
+        ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
+        vol_points = vol_points[ind]
+        vol_label = vol_label[ind]
+        vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
+
+        ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
+        near_points = near_points[ind]
+        near_label = near_label[ind]
+        near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
+
+        # concat sampled volume and near points
+        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
+
+        sample = {
+            "surface": surface,
+            "geo_points": geo_points
+        }
+
+        return sample
+
+
+class FeatureSelection(object):
+
+    VALID_SURFACE_FEATURE_DIMS = {
+        "none": [0, 1, 2],                              # xyz
+        "watertight_normal": [0, 1, 2, 3, 4, 5],        # xyz, normal
+        "normal": [0, 1, 2, 6, 7, 8]
+    }
+
+    def __init__(self, surface_feature_type: str):
+
+        self.surface_feature_type = surface_feature_type
+        self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
+
+    def __call__(self, sample):
+        sample["surface"] = sample["surface"][:, self.surface_dims]
+        return sample
+
+
+class AxisScaleTransform(object):
+    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
+        assert isinstance(interval, (tuple, list, ListConfig))
+        self.interval = interval
+        self.min_val = interval[0]
+        self.max_val = interval[1]
+        self.inter_size = interval[1] - interval[0]
+        self.jitter = jitter
+        self.jitter_scale = jitter_scale
+
+    def __call__(self, sample):
+
+        surface = sample["surface"][..., 0:3]
+        geo_points = sample["geo_points"][..., 0:3]
+
+        scaling = torch.rand(1, 3) * self.inter_size + self.min_val
+        # print(scaling)
+        surface = surface * scaling
+        geo_points = geo_points * scaling
+
+        scale = (1 / torch.abs(surface).max().item()) * 0.999999
+        surface *= scale
+        geo_points *= scale
+
+        if self.jitter:
+            surface += self.jitter_scale * torch.randn_like(surface)
+            surface.clamp_(min=-1.015, max=1.015)
+
+        sample["surface"][..., 0:3] = surface
+        sample["geo_points"][..., 0:3] = geo_points
            +
             
     | 
| 232 | 
         
            +
                    return sample
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
            class ToTensor(object):
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
         
     | 
| 238 | 
         
            +
                    self.tensor_keys = tensor_keys
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                def __call__(self, sample):
         
     | 
| 241 | 
         
            +
                    for key in self.tensor_keys:
         
     | 
| 242 | 
         
            +
                        if key not in sample:
         
     | 
| 243 | 
         
            +
                            continue
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                        sample[key] = torch.tensor(sample[key], dtype=torch.float32)
         
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
                    return sample
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
            class AxisScale(object):
         
     | 
| 251 | 
         
            +
                def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
         
     | 
| 252 | 
         
            +
                    assert isinstance(interval, (tuple, list, ListConfig))
         
     | 
| 253 | 
         
            +
                    self.interval = interval
         
     | 
| 254 | 
         
            +
                    self.jitter = jitter
         
     | 
| 255 | 
         
            +
                    self.jitter_scale = jitter_scale
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                def __call__(self, surface, *args):
         
     | 
| 258 | 
         
            +
                    scaling = torch.rand(1, 3) * 0.5 + 0.75
         
     | 
| 259 | 
         
            +
                    # print(scaling)
         
     | 
| 260 | 
         
            +
                    surface = surface * scaling
         
     | 
| 261 | 
         
            +
                    scale = (1 / torch.abs(surface).max().item()) * 0.999999
         
     | 
| 262 | 
         
            +
                    surface *= scale
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                    args_outputs = []
         
     | 
| 265 | 
         
            +
                    for _arg in args:
         
     | 
| 266 | 
         
            +
                        _arg = _arg * scaling * scale
         
     | 
| 267 | 
         
            +
                        args_outputs.append(_arg)
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                    if self.jitter:
         
     | 
| 270 | 
         
            +
                        surface += self.jitter_scale * torch.randn_like(surface)
         
     | 
| 271 | 
         
            +
                        surface.clamp_(min=-1, max=1)
         
     | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
            +
                    if len(args) == 0:
         
     | 
| 274 | 
         
            +
                        return surface
         
     | 
| 275 | 
         
            +
                    else:
         
     | 
| 276 | 
         
            +
                        return surface, *args_outputs
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
             
     | 
| 279 | 
         
            +
            class RandomResize(torch.nn.Module):
         
     | 
| 280 | 
         
            +
                """Apply randomly Resize with a given probability."""
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                def __init__(
         
     | 
| 283 | 
         
            +
                    self,
         
     | 
| 284 | 
         
            +
                    size,
         
     | 
| 285 | 
         
            +
                    resize_radio=(0.5, 1),
         
     | 
| 286 | 
         
            +
                    allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
         
     | 
| 287 | 
         
            +
                    interpolation=InterpolationMode.BICUBIC,
         
     | 
| 288 | 
         
            +
                    max_size=None,
         
     | 
| 289 | 
         
            +
                    antialias=None,
         
     | 
| 290 | 
         
            +
                ):
         
     | 
| 291 | 
         
            +
                    super().__init__()
         
     | 
| 292 | 
         
            +
                    if not isinstance(size, (int, Sequence)):
         
     | 
| 293 | 
         
            +
                        raise TypeError(f"Size should be int or sequence. Got {type(size)}")
         
     | 
| 294 | 
         
            +
                    if isinstance(size, Sequence) and len(size) not in (1, 2):
         
     | 
| 295 | 
         
            +
                        raise ValueError("If size is a sequence, it should have 1 or 2 values")
         
     | 
| 296 | 
         
            +
             
     | 
| 297 | 
         
            +
                    self.size = size
         
     | 
| 298 | 
         
            +
                    self.max_size = max_size
         
     | 
| 299 | 
         
            +
                    # Backward compatibility with integer value
         
     | 
| 300 | 
         
            +
                    if isinstance(interpolation, int):
         
     | 
| 301 | 
         
            +
                        warnings.warn(
         
     | 
| 302 | 
         
            +
                            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
         
     | 
| 303 | 
         
            +
                            "Please use InterpolationMode enum."
         
     | 
| 304 | 
         
            +
                        )
         
     | 
| 305 | 
         
            +
                        interpolation = _interpolation_modes_from_int(interpolation)
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
                    self.interpolation = interpolation
         
     | 
| 308 | 
         
            +
                    self.antialias = antialias
         
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
                    self.resize_radio = resize_radio
         
     | 
| 311 | 
         
            +
                    self.allow_resize_interpolations = allow_resize_interpolations
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                def random_resize_params(self):
         
     | 
| 314 | 
         
            +
                    radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
         
     | 
| 315 | 
         
            +
             
     | 
| 316 | 
         
            +
                    if isinstance(self.size, int):
         
     | 
| 317 | 
         
            +
                        size = int(self.size * radio)
         
     | 
| 318 | 
         
            +
                    elif isinstance(self.size, Sequence):
         
     | 
| 319 | 
         
            +
                        size = list(self.size)
         
     | 
| 320 | 
         
            +
                        size = (int(size[0] * radio), int(size[1] * radio))
         
     | 
| 321 | 
         
            +
                    else:
         
     | 
| 322 | 
         
            +
                        raise RuntimeError()
         
     | 
| 323 | 
         
            +
             
     | 
| 324 | 
         
            +
                    interpolation = self.allow_resize_interpolations[
         
     | 
| 325 | 
         
            +
                        torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
         
     | 
| 326 | 
         
            +
                    ]
         
     | 
| 327 | 
         
            +
                    return size, interpolation
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
                def forward(self, img):
         
     | 
| 330 | 
         
            +
                    size, interpolation = self.random_resize_params()
         
     | 
| 331 | 
         
            +
                    img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
         
     | 
| 332 | 
         
            +
                    img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
         
     | 
| 333 | 
         
            +
                    return img
         
     | 
| 334 | 
         
            +
             
     | 
| 335 | 
         
            +
                def __repr__(self) -> str:
         
     | 
| 336 | 
         
            +
                    detail = f"(size={self.size}, interpolation={self.interpolation.value},"
         
     | 
| 337 | 
         
            +
                    detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
         
     | 
| 338 | 
         
            +
                    return f"{self.__class__.__name__}{detail}"
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
             
     | 
| 341 | 
         
            +
            class Compose(object):
         
     | 
| 342 | 
         
            +
                """Composes several transforms together. This transform does not support torchscript.
         
     | 
| 343 | 
         
            +
                Please, see the note below.
         
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
                Args:
         
     | 
| 346 | 
         
            +
                    transforms (list of ``Transform`` objects): list of transforms to compose.
         
     | 
| 347 | 
         
            +
             
     | 
| 348 | 
         
            +
                Example:
         
     | 
| 349 | 
         
            +
                    >>> transforms.Compose([
         
     | 
| 350 | 
         
            +
                    >>>     transforms.CenterCrop(10),
         
     | 
| 351 | 
         
            +
                    >>>     transforms.ToTensor(),
         
     | 
| 352 | 
         
            +
                    >>> ])
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                .. note::
         
     | 
| 355 | 
         
            +
                    In order to script the transformations, please use ``torch.nn.Sequential`` as below.
         
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
                    >>> transforms = torch.nn.Sequential(
         
     | 
| 358 | 
         
            +
                    >>>     transforms.CenterCrop(10),
         
     | 
| 359 | 
         
            +
                    >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
         
     | 
| 360 | 
         
            +
                    >>> )
         
     | 
| 361 | 
         
            +
                    >>> scripted_transforms = torch.jit.script(transforms)
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                    Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
         
     | 
| 364 | 
         
            +
                    `lambda` functions or ``PIL.Image``.
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                """
         
     | 
| 367 | 
         
            +
             
     | 
| 368 | 
         
            +
                def __init__(self, transforms):
         
     | 
| 369 | 
         
            +
                    self.transforms = transforms
         
     | 
| 370 | 
         
            +
             
     | 
| 371 | 
         
            +
                def __call__(self, *args):
         
     | 
| 372 | 
         
            +
                    for t in self.transforms:
         
     | 
| 373 | 
         
            +
                        args = t(*args)
         
     | 
| 374 | 
         
            +
                    return args
         
     | 
| 375 | 
         
            +
             
     | 
| 376 | 
         
            +
                def __repr__(self):
         
     | 
| 377 | 
         
            +
                    format_string = self.__class__.__name__ + '('
         
     | 
| 378 | 
         
            +
                    for t in self.transforms:
         
     | 
| 379 | 
         
            +
                        format_string += '\n'
         
     | 
| 380 | 
         
            +
                        format_string += '    {0}'.format(t)
         
     | 
| 381 | 
         
            +
                    format_string += '\n)'
         
     | 
| 382 | 
         
            +
                    return format_string
         
     | 
| 383 | 
         
            +
             
     | 
| 384 | 
         
            +
             
     | 
| 385 | 
         
            +
            def identity(*args, **kwargs):
         
     | 
| 386 | 
         
            +
                if len(args) == 1:
         
     | 
| 387 | 
         
            +
                    return args[0]
         
     | 
| 388 | 
         
            +
                else:
         
     | 
| 389 | 
         
            +
                    return args
         
     | 
| 390 | 
         
            +
             
     | 
| 391 | 
         
            +
             
     | 
| 392 | 
         
            +
            def build_transforms(cfg):
         
     | 
| 393 | 
         
            +
             
     | 
| 394 | 
         
            +
                if cfg is None:
         
     | 
| 395 | 
         
            +
                    return identity
         
     | 
| 396 | 
         
            +
             
     | 
| 397 | 
         
            +
                transforms = []
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                for transform_name, cfg_instance in cfg.items():
         
     | 
| 400 | 
         
            +
                    transform_instance = instantiate_from_config(cfg_instance)
         
     | 
| 401 | 
         
            +
                    transforms.append(transform_instance)
         
     | 
| 402 | 
         
            +
                    print(f"Build transform: {transform_instance}")
         
     | 
| 403 | 
         
            +
             
     | 
| 404 | 
         
            +
                transforms = Compose(transforms)
         
     | 
| 405 | 
         
            +
             
     | 
| 406 | 
         
            +
                return transforms
         
     | 
| 407 | 
         
            +
             
     | 
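For orientation, here is a minimal usage sketch (not part of the diff) chaining the dict-based transforms above on a dummy sample. The shapes, the feature layout, and the explicit loop (instead of Compose, whose *args unpacking targets positional-style transforms such as AxisScale) are illustrative assumptions:

import numpy as np

from primitive_anything.michelangelo.data.transforms import (
    AxisScaleTransform,
    FeatureSelection,
    ToTensor,
)

# dummy sample: 9-dim surface features (xyz + watertight normal + normal)
# and geo points whose 4th column is an occupancy/SDF label
sample = {
    "surface": np.random.rand(2048, 9).astype(np.float32),
    "geo_points": np.random.rand(1024, 4).astype(np.float32),
}

# ToTensor first, since AxisScaleTransform uses torch ops internally
for t in (ToTensor(), AxisScaleTransform(interval=(0.75, 1.25)), FeatureSelection("normal")):
    sample = t(sample)

print(sample["surface"].shape)     # torch.Size([2048, 6]) -> xyz + normal
print(sample["geo_points"].shape)  # torch.Size([1024, 4])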
    	
primitive_anything/michelangelo/data/utils.py
ADDED
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

import torch
import numpy as np


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
    worker_id = worker_info.id

    # dataset = worker_info.dataset
    # split_size = dataset.num_records // worker_info.num_workers
    # # reset num_records to the true number to retain reliable length information
    # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
    # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
    # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)

    return np.random.seed(np.random.get_state()[1][0] + worker_id)


def collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """
    Args:
        samples (list[dict]): per-sample dicts to batch.
        combine_tensors: stack torch.Tensor / np.ndarray values when True.
        combine_scalars: pack int/float values into an np.ndarray when True.

    Returns:
        dict: the same keys as the input samples, with values batched.
    """

    result = {}

    keys = samples[0].keys()

    for key in keys:
        result[key] = []

    for sample in samples:
        for key in keys:
            val = sample[key]
            result[key].append(val)

    for key in keys:
        val_list = result[key]
        if isinstance(val_list[0], (int, float)):
            if combine_scalars:
                result[key] = np.array(result[key])

        elif isinstance(val_list[0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(val_list)

        elif isinstance(val_list[0], np.ndarray):
            if combine_tensors:
                result[key] = np.stack(val_list)

    return result
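As a usage sketch (assumed, not from the diff), both helpers plug straight into a torch DataLoader; ToyDictDataset below is a hypothetical stand-in for the repo's datasets:

from functools import partial

import numpy as np
from torch.utils.data import DataLoader, Dataset

from primitive_anything.michelangelo.data.utils import collation_fn, worker_init_fn


class ToyDictDataset(Dataset):
    """Hypothetical stand-in: yields dict samples like the real datasets do."""

    def __len__(self):
        return 128

    def __getitem__(self, idx):
        return {"surface": np.random.rand(2048, 6).astype(np.float32), "idx": idx}


loader = DataLoader(
    ToyDictDataset(),
    batch_size=32,
    num_workers=4,
    worker_init_fn=worker_init_fn,      # reseeds numpy independently per worker
    collate_fn=partial(collation_fn, combine_tensors=True, combine_scalars=True),
)

batch = next(iter(loader))
print(batch["surface"].shape, batch["idx"].shape)  # (32, 2048, 6) (32,)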
    	
primitive_anything/michelangelo/graphics/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
    	
primitive_anything/michelangelo/graphics/primitives/__init__.py
ADDED
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .volume import generate_dense_grid_points

from .mesh import (
    MeshOutput,
    save_obj,
    savemeshtes2
)
        primitive_anything/michelangelo/graphics/primitives/mesh.py
    ADDED
    
    | 
         @@ -0,0 +1,114 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # -*- coding: utf-8 -*-
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
            import cv2
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            import PIL.Image
         
     | 
| 7 | 
         
            +
            from typing import Optional
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import trimesh
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            def save_obj(pointnp_px3, facenp_fx3, fname):
         
     | 
| 13 | 
         
            +
                fid = open(fname, "w")
         
     | 
| 14 | 
         
            +
                write_str = ""
         
     | 
| 15 | 
         
            +
                for pidx, p in enumerate(pointnp_px3):
         
     | 
| 16 | 
         
            +
                    pp = p
         
     | 
| 17 | 
         
            +
                    write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
                for i, f in enumerate(facenp_fx3):
         
     | 
| 20 | 
         
            +
                    f1 = f + 1
         
     | 
| 21 | 
         
            +
                    write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
         
     | 
| 22 | 
         
            +
                fid.write(write_str)
         
     | 
| 23 | 
         
            +
                fid.close()
         
     | 
| 24 | 
         
            +
                return
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
         
     | 
| 28 | 
         
            +
                fol, na = os.path.split(fname)
         
     | 
| 29 | 
         
            +
                na, _ = os.path.splitext(na)
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                matname = "%s/%s.mtl" % (fol, na)
         
     | 
| 32 | 
         
            +
                fid = open(matname, "w")
         
     | 
| 33 | 
         
            +
                fid.write("newmtl material_0\n")
         
     | 
| 34 | 
         
            +
                fid.write("Kd 1 1 1\n")
         
     | 
| 35 | 
         
            +
                fid.write("Ka 0 0 0\n")
         
     | 
| 36 | 
         
            +
                fid.write("Ks 0.4 0.4 0.4\n")
         
     | 
| 37 | 
         
            +
                fid.write("Ns 10\n")
         
     | 
| 38 | 
         
            +
                fid.write("illum 2\n")
         
     | 
| 39 | 
         
            +
                fid.write("map_Kd %s.png\n" % na)
         
     | 
| 40 | 
         
            +
                fid.close()
         
     | 
| 41 | 
         
            +
                ####
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                fid = open(fname, "w")
         
     | 
| 44 | 
         
            +
                fid.write("mtllib %s.mtl\n" % na)
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                for pidx, p in enumerate(pointnp_px3):
         
     | 
| 47 | 
         
            +
                    pp = p
         
     | 
| 48 | 
         
            +
                    fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                for pidx, p in enumerate(tcoords_px2):
         
     | 
| 51 | 
         
            +
                    pp = p
         
     | 
| 52 | 
         
            +
                    fid.write("vt %f %f\n" % (pp[0], pp[1]))
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                fid.write("usemtl material_0\n")
         
     | 
| 55 | 
         
            +
                for i, f in enumerate(facenp_fx3):
         
     | 
| 56 | 
         
            +
                    f1 = f + 1
         
     | 
| 57 | 
         
            +
                    f2 = facetex_fx3[i] + 1
         
     | 
| 58 | 
         
            +
                    fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
         
     | 
| 59 | 
         
            +
                fid.close()
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
         
     | 
| 62 | 
         
            +
                    os.path.join(fol, "%s.png" % na))
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                return
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
            class MeshOutput(object):
         
     | 
| 68 | 
         
            +
             
     | 
| 69 | 
         
            +
                def __init__(self,
         
     | 
| 70 | 
         
            +
                             mesh_v: np.ndarray,
         
     | 
| 71 | 
         
            +
                             mesh_f: np.ndarray,
         
     | 
| 72 | 
         
            +
                             vertex_colors: Optional[np.ndarray] = None,
         
     | 
| 73 | 
         
            +
                             uvs: Optional[np.ndarray] = None,
         
     | 
| 74 | 
         
            +
                             mesh_tex_idx: Optional[np.ndarray] = None,
         
     | 
| 75 | 
         
            +
                             tex_map: Optional[np.ndarray] = None):
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                    self.mesh_v = mesh_v
         
     | 
| 78 | 
         
            +
                    self.mesh_f = mesh_f
         
     | 
| 79 | 
         
            +
                    self.vertex_colors = vertex_colors
         
     | 
| 80 | 
         
            +
                    self.uvs = uvs
         
     | 
| 81 | 
         
            +
                    self.mesh_tex_idx = mesh_tex_idx
         
     | 
| 82 | 
         
            +
                    self.tex_map = tex_map
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                def contain_uv_texture(self):
         
     | 
| 85 | 
         
            +
                    return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                def contain_vertex_colors(self):
         
     | 
| 88 | 
         
            +
                    return self.vertex_colors is not None
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                def export(self, fname):
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    if self.contain_uv_texture():
         
     | 
| 93 | 
         
            +
                        savemeshtes2(
         
     | 
| 94 | 
         
            +
                            self.mesh_v,
         
     | 
| 95 | 
         
            +
                            self.uvs,
         
     | 
| 96 | 
         
            +
                            self.mesh_f,
         
     | 
| 97 | 
         
            +
                            self.mesh_tex_idx,
         
     | 
| 98 | 
         
            +
                            self.tex_map,
         
     | 
| 99 | 
         
            +
                            fname
         
     | 
| 100 | 
         
            +
                        )
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
                    elif self.contain_vertex_colors():
         
     | 
| 103 | 
         
            +
                        mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
         
     | 
| 104 | 
         
            +
                        mesh_obj.export(fname)
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                    else:
         
     | 
| 107 | 
         
            +
                        save_obj(
         
     | 
| 108 | 
         
            +
                            self.mesh_v,
         
     | 
| 109 | 
         
            +
                            self.mesh_f,
         
     | 
| 110 | 
         
            +
                            fname
         
     | 
| 111 | 
         
            +
                        )
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
             
     | 
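A quick, illustrative export (assumed, not from the diff): with neither UVs nor vertex colors provided, MeshOutput.export falls through to the plain save_obj writer.

import numpy as np

from primitive_anything.michelangelo.graphics.primitives import MeshOutput

# a two-triangle quad; the vertex and face arrays are made-up data
verts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int64)

MeshOutput(mesh_v=verts, mesh_f=faces).export("quad.obj")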
    	
primitive_anything/michelangelo/graphics/primitives/volume.py
ADDED
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

import numpy as np


def generate_dense_grid_points(bbox_min: np.ndarray,
                               bbox_max: np.ndarray,
                               octree_depth: int,
                               indexing: str = "ij"):
    length = bbox_max - bbox_min
    num_cells = np.exp2(octree_depth)
    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    xyz = xyz.reshape(-1, 3)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length
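For intuition (example assumed, not in the diff): at octree depth d the grid has (2^d + 1)^3 query points, so depth 7 over [-1, 1]^3 yields 129^3 = 2,146,689 points.

import numpy as np

from primitive_anything.michelangelo.graphics.primitives import generate_dense_grid_points

xyz, grid_size, length = generate_dense_grid_points(
    bbox_min=np.array([-1.0, -1.0, -1.0]),
    bbox_max=np.array([1.0, 1.0, 1.0]),
    octree_depth=7,
)
print(xyz.shape, grid_size, length)  # (2146689, 3) [129, 129, 129] [2. 2. 2.]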
    	
primitive_anything/michelangelo/models/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
    	
primitive_anything/michelangelo/models/asl_diffusion/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
    	
primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py
ADDED
@@ -0,0 +1,483 @@
# -*- coding: utf-8 -*-

from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

from einops import rearrange

from diffusers.schedulers import (
    DDPMScheduler,
    DDIMScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler
)

from ...utils import instantiate_from_config
# from ..tsal.tsal_base import ShapeAsLatentPLModule
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample

SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class ASLDiffuser(pl.LightningModule):
    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
    # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
    model: nn.Module

    def __init__(self, *,
                 first_stage_config,
                 denoiser_cfg,
                 scheduler_cfg,
                 optimizer_cfg,
                 loss_cfg,
                 first_stage_key: str = "surface",
                 cond_stage_key: str = "image",
                 cond_stage_trainable: bool = True,
                 scale_by_std: bool = False,
                 z_scale_factor: float = 1.0,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):

        super().__init__()

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.cond_stage_trainable = cond_stage_trainable

        # 1. initialize the first stage.
        # Note: the condition model is contained in the first stage model.
        self.first_stage_config = first_stage_config
        self.first_stage_model = None
        # self.instantiate_first_stage(first_stage_config)

        # 2. initialize the conditional stage
        # self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_model = {
            "image": self.encode_image,
            "image_unconditional_embedding": self.empty_img_cond,
            "text": self.encode_text,
            "text_unconditional_embedding": self.empty_text_cond,
            "surface": self.encode_surface,
            "surface_unconditional_embedding": self.empty_surface_cond,
        }

        # 3. diffusion model
        self.model = instantiate_from_config(
            denoiser_cfg, device=None, dtype=None
        )

        self.optimizer_cfg = optimizer_cfg

        # 4. scheduling strategy
        self.scheduler_cfg = scheduler_cfg

        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)

        # 5. loss configuration
        self.loss_cfg = loss_cfg

        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
     | 
| 106 | 
         
            +
                    self.first_stage_model = model.eval()
         
     | 
| 107 | 
         
            +
                    self.first_stage_model.train = disabled_train
         
     | 
| 108 | 
         
            +
                    for param in self.first_stage_model.parameters():
         
     | 
| 109 | 
         
            +
                        param.requires_grad = False
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    self.first_stage_model = self.first_stage_model.to(self.device)
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                # def instantiate_cond_stage(self, config):
         
     | 
| 114 | 
         
            +
                #     if not self.cond_stage_trainable:
         
     | 
| 115 | 
         
            +
                #         if config == "__is_first_stage__":
         
     | 
| 116 | 
         
            +
                #             print("Using first stage also as cond stage.")
         
     | 
| 117 | 
         
            +
                #             self.cond_stage_model = self.first_stage_model
         
     | 
| 118 | 
         
            +
                #         elif config == "__is_unconditional__":
         
     | 
| 119 | 
         
            +
                #             print(f"Training {self.__class__.__name__} as an unconditional model.")
         
     | 
| 120 | 
         
            +
                #             self.cond_stage_model = None
         
     | 
| 121 | 
         
            +
                #             # self.be_unconditional = True
         
     | 
| 122 | 
         
            +
                #         else:
         
     | 
| 123 | 
         
            +
                #             model = instantiate_from_config(config)
         
     | 
| 124 | 
         
            +
                #             self.cond_stage_model = model.eval()
         
     | 
| 125 | 
         
            +
                #             self.cond_stage_model.train = disabled_train
         
     | 
| 126 | 
         
            +
                #             for param in self.cond_stage_model.parameters():
         
     | 
| 127 | 
         
            +
                #                 param.requires_grad = False
         
     | 
| 128 | 
         
            +
                #     else:
         
     | 
| 129 | 
         
            +
                #         assert config != "__is_first_stage__"
         
     | 
| 130 | 
         
            +
                #         assert config != "__is_unconditional__"
         
     | 
| 131 | 
         
            +
                #         model = instantiate_from_config(config)
         
     | 
| 132 | 
         
            +
                #         self.cond_stage_model = model
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
                def init_from_ckpt(self, path, ignore_keys=()):
         
     | 
| 135 | 
         
            +
                    state_dict = torch.load(path, map_location="cpu")["state_dict"]
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    keys = list(state_dict.keys())
         
     | 
| 138 | 
         
            +
                    for k in keys:
         
     | 
| 139 | 
         
            +
                        for ik in ignore_keys:
         
     | 
| 140 | 
         
            +
                            if k.startswith(ik):
         
     | 
| 141 | 
         
            +
                                print("Deleting key {} from state_dict.".format(k))
         
     | 
| 142 | 
         
            +
                                del state_dict[k]
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    missing, unexpected = self.load_state_dict(state_dict, strict=False)
         
     | 
| 145 | 
         
            +
                    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
         
     | 
| 146 | 
         
            +
                    if len(missing) > 0:
         
     | 
| 147 | 
         
            +
                        print(f"Missing Keys: {missing}")
         
     | 
| 148 | 
         
            +
                        print(f"Unexpected Keys: {unexpected}")
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                @property
         
     | 
| 151 | 
         
            +
                def zero_rank(self):
         
     | 
| 152 | 
         
            +
                    if self._trainer:
         
     | 
| 153 | 
         
            +
                        zero_rank = self.trainer.local_rank == 0
         
     | 
| 154 | 
         
            +
                    else:
         
     | 
| 155 | 
         
            +
                        zero_rank = True
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
                    return zero_rank
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                def configure_optimizers(self) -> Tuple[List, List]:
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                    lr = self.learning_rate
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                    trainable_parameters = list(self.model.parameters())
         
     | 
| 164 | 
         
            +
                    # if the conditional encoder is trainable
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                    # if self.cond_stage_trainable:
         
     | 
| 167 | 
         
            +
                    #     conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
         
     | 
| 168 | 
         
            +
                    #     trainable_parameters += conditioner_params
         
     | 
| 169 | 
         
            +
                    #     print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                    if self.optimizer_cfg is None:
         
     | 
| 172 | 
         
            +
                        optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
         
     | 
| 173 | 
         
            +
                        schedulers = []
         
     | 
| 174 | 
         
            +
                    else:
         
     | 
| 175 | 
         
            +
                        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
         
     | 
| 176 | 
         
            +
                        scheduler_func = instantiate_from_config(
         
     | 
| 177 | 
         
            +
                            self.optimizer_cfg.scheduler,
         
     | 
| 178 | 
         
            +
                            max_decay_steps=self.trainer.max_steps,
         
     | 
| 179 | 
         
            +
                            lr_max=lr
         
     | 
| 180 | 
         
            +
                        )
         
     | 
| 181 | 
         
            +
                        scheduler = {
         
     | 
| 182 | 
         
            +
                            "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
         
     | 
| 183 | 
         
            +
                            "interval": "step",
         
     | 
| 184 | 
         
            +
                            "frequency": 1
         
     | 
| 185 | 
         
            +
                        }
         
     | 
| 186 | 
         
            +
                        optimizers = [optimizer]
         
     | 
| 187 | 
         
            +
                        schedulers = [scheduler]
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                    return optimizers, schedulers
         
     | 
| 190 | 
         
            +
             
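    # Note on the scheduler contract above: with "interval": "step", Lightning
    # advances the LambdaLR after every optimizer step, and LambdaLR calls
    # `scheduler_func.schedule` with the current step count, multiplying the
    # base learning rate by its return value. A minimal sketch of such a
    # schedule (a hypothetical linear warmup, not the one configured here):
    #
    #     def schedule(step, warmup=1000):
    #         return min(1.0, (step + 1) / warmup)
    #
    #     lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)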
    @torch.no_grad()
    def encode_text(self, text):

        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        return text_embed

    @torch.no_grad()
    def encode_image(self, img):

        return self.first_stage_model.model.encode_image_embed(img)

    @torch.no_grad()
    def encode_surface(self, surface):

        return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)

    @torch.no_grad()
    def empty_text_cond(self, cond):

        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def empty_img_cond(self, cond):

        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def empty_surface_cond(self, cond):

        return torch.zeros_like(cond, device=cond.device)

    @torch.no_grad()
    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):

        z_q = self.first_stage_model.encode(surface, sample_posterior)
        z_q = self.z_scale_factor * z_q

        return z_q

    @torch.no_grad()
    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):

        z_q = 1. / self.z_scale_factor * z_q
        latents = self.first_stage_model.decode(z_q, **kwargs)
        return latents

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")

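    # Note on std-rescaling above: z_scale_factor = 1 / std(z) normalizes the
    # first-stage latents to roughly unit standard deviation before diffusion,
    # the same trick as the scale_factor in latent diffusion models. It is
    # computed once, on the very first training batch, and only when training
    # from scratch (ckpt_path is None). A sketch of the effect (hypothetical
    # names, for illustration only):
    #
    #     z = first_stage.encode(x)          # arbitrary scale
    #     z_scaled = (1.0 / z.std()) * z     # z_scaled.std() ~= 1.0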
    def compute_loss(self, model_outputs, split):
        """

        Args:
            model_outputs (dict):
                - x_0:
                - noise:
                - pred:

            split (str): "train" or "val"

        Returns:

        """

        pred = model_outputs["pred"]

        if self.noise_scheduler.prediction_type == "epsilon":
            target = model_outputs["noise"]
        elif self.noise_scheduler.prediction_type == "sample":
            target = model_outputs["x_0"]
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")

        if self.loss_cfg.loss_type == "l1":
            simple = F.l1_loss(pred, target, reduction="mean")
        elif self.loss_cfg.loss_type in ["mse", "l2"]:
            simple = F.mse_loss(pred, target, reduction="mean")
        else:
            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")

        total_loss = simple

        loss_dict = {
            f"{split}/total_loss": total_loss.clone().detach(),
            f"{split}/simple": simple.detach(),
        }

        return total_loss, loss_dict

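    # Note on the target choice above: with prediction_type == "epsilon" this
    # is the standard DDPM simple objective,
    #
    #     L_simple = E_{x_0, eps, t} [ || eps - eps_theta(x_t, t) ||^2 ],
    #
    # i.e. the denoiser regresses the sampled noise; with "sample" it
    # regresses the clean latent x_0 directly. Either way the loss reduces to
    # a plain L1/L2 distance between `pred` and `target`.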
    def forward(self, batch):
        """

        Args:
            batch:

        Returns:

        """

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        latents = self.encode_first_stage(batch[self.first_stage_key])

        # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])

        conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)

        mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
        conditions = conditions * mask.to(conditions)

        # Sample noise that we'll add to the latents
        # [batch_size, n_token, latent_dim]
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # diffusion model forward
        noise_pred = self.model(noisy_z, timesteps, conditions)

        diffusion_outputs = {
            "x_0": latents,  # clean latents, the "sample"-prediction target
            "noise": noise,
            "pred": noise_pred
        }

        return diffusion_outputs

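    # Note on the random mask in forward(): zeroing the condition embedding
    # for ~10% of the batch (torch.rand(...) >= 0.1 keeps a sample's
    # condition with probability 0.9) is classifier-free guidance training:
    # one denoiser learns both the conditional and the unconditional score,
    # which sample() below exploits at inference time. The
    # (len(conditions), 1, 1) mask shape broadcasts a single keep/drop
    # decision over each sample's token and feature dimensions.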
    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """

        Args:
            batch (dict): the batch sample; it contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], item range [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], item range [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], item range [-1, 1]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):

        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    def validation_step(self, batch: Dict[str, torch.FloatTensor],
                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """

        Args:
            batch (dict): the batch sample; it contains:
                - surface_pc (torch.FloatTensor): [n_pts, 4]
                - surface_feats (torch.FloatTensor): [n_pts, c]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):

        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    @torch.no_grad()
    def sample(self,
               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False, **kwargs):

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        if steps is None:
            steps = self.scheduler_cfg.num_inference_steps

        if guidance_scale is None:
            guidance_scale = self.scheduler_cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        # conditional encode
        xc = batch[self.cond_stage_key]
        # cond = self.cond_stage_model[self.cond_stage_key](xc)
        cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)

        if do_classifier_free_guidance:
            """
            Note: there are two kinds of unconditional embedding for text:
            1. using "" as the uncond text (as in SAL diffusion);
            2. zeros_like(cond) as the uncond text (as in MDM).
            """
            # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
            un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
            # un_cond = torch.zeros_like(cond, device=cond.device)
            cond = torch.cat([un_cond, cond], dim=0)

        outputs = []
        latents = None

        if not return_intermediates:
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.model,
                    shape=self.first_stage_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=not self.zero_rank
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.decode_first_stage(latents, **kwargs))
        else:

            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.model,
                shape=self.first_stage_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=not self.zero_rank
            )

            iter_size = steps // sample_times
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.decode_first_stage(latents, **kwargs))
                i += 1

        return outputs
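With classifier-free guidance enabled, `cond` stacks the unconditional and
conditional embeddings along the batch dimension, so each denoiser call can
evaluate both in one pass. The guidance combination itself happens inside
`ddim_sample` from inference_utils.py, which is not shown in this hunk; a
minimal sketch of the conventional update it is expected to perform, with
hypothetical variable names:

    noise_pred = model(torch.cat([x_t, x_t]), t, cond)           # doubled batch
    eps_uncond, eps_cond = noise_pred.chunk(2)
    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)  # CFG combine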
primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py
ADDED
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Optional
from diffusers.models.embeddings import Timesteps
import math

from ..modules.transformer_blocks import MLP
from ..modules.diffusion_transformer import UNetDiffusionTransformer


class ConditionalASLUDTDenoiser(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 input_channels: int,
                 output_channels: int,
                 n_ctx: int,
                 width: int,
                 layers: int,
                 heads: int,
                 context_dim: int,
                 context_ln: bool = True,
                 skip_ln: bool = False,
                 init_scale: float = 0.25,
                 flip_sin_to_cos: bool = False,
                 use_checkpoint: bool = False):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        init_scale = init_scale * math.sqrt(1.0 / width)

        self.backbone = UNetDiffusionTransformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=layers,
            heads=heads,
            skip_ln=skip_ln,
            init_scale=init_scale,
            use_checkpoint=use_checkpoint
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)

        # timestep embedding
        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale
        )

        # context (condition) projector; optionally LayerNorm the input first
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):

        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]

        """

        _, n_data, _ = model_input.shape

        # 1. time
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # 2. conditions projector
        context = self.context_embed(context)

        # 3. denoiser
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        x = x[:, -n_data:]
        sample = self.output_proj(x)

        return sample
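The denoiser operates on one token sequence: a single time-embedding token,
then the projected context tokens, then the n_data latent tokens, all
concatenated along dim=1. Because the transformer backbone preserves sequence
length, slicing the last n_data positions after the backbone recovers exactly
the latent tokens. A minimal shape check with illustrative sizes (plain torch
standing in for the real modules):

    import torch
    bs, n_ctx_tokens, n_data, width = 2, 1, 512, 768
    t_emb = torch.randn(bs, 1, width)              # one time token per sample
    context = torch.randn(bs, n_ctx_tokens, width)
    x = torch.randn(bs, n_data, width)
    seq = torch.cat([t_emb, context, x], dim=1)    # [bs, 1 + n_ctx_tokens + n_data, width]
    out = seq[:, -n_data:]                         # back to [bs, n_data, width]
    assert out.shape == x.shape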
primitive_anything/michelangelo/models/asl_diffusion/base.py
ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn


class BaseDenoiser(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, t, context):
        raise NotImplementedError
primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py
ADDED
@@ -0,0 +1,393 @@
# -*- coding: utf-8 -*-

from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

from diffusers.schedulers import (
    DDPMScheduler,
    DDIMScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler
)

from ...utils import instantiate_from_config
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample

SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class ClipASLDiffuser(pl.LightningModule):
    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
    cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
    model: nn.Module

    def __init__(self, *,
                 first_stage_config,
                 cond_stage_config,
                 denoiser_cfg,
                 scheduler_cfg,
                 optimizer_cfg,
                 loss_cfg,
                 first_stage_key: str = "surface",
                 cond_stage_key: str = "image",
                 scale_by_std: bool = False,
                 z_scale_factor: float = 1.0,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):

        super().__init__()

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        # 1. lazily initialize the first stage
        self.instantiate_first_stage(first_stage_config)

        # 2. initialize the conditional stage
        self.instantiate_cond_stage(cond_stage_config)

        # 3. diffusion model
        self.model = instantiate_from_config(
            denoiser_cfg, device=None, dtype=None
        )

        self.optimizer_cfg = optimizer_cfg

        # 4. scheduling strategy
        self.scheduler_cfg = scheduler_cfg

        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)

        # 5. loss configuration
        self.loss_cfg = loss_cfg

        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def instantiate_non_trainable_model(self, config):
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False

        return model

    def instantiate_first_stage(self, first_stage_config):
        self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
        self.first_stage_model.set_shape_model_only()

    def instantiate_cond_stage(self, cond_stage_config):
        self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)

    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    @property
    def zero_rank(self):
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True

        return zero_rank

    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())
        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers

    @torch.no_grad()
    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
        z_q = self.first_stage_model.encode(surface, sample_posterior)
        z_q = self.z_scale_factor * z_q

        return z_q

    @torch.no_grad()
    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
        z_q = 1. / self.z_scale_factor * z_q
        latents = self.first_stage_model.decode(z_q, **kwargs)
        return latents

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # only for the very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            # set the rescale weight to 1./std of the encodings
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")

    def compute_loss(self, model_outputs, split):
        """
        Args:
            model_outputs (dict):
                - x_0: the clean latents
                - noise: the noise added during the forward process
                - pred: the denoiser's prediction

            split (str): "train" or "val"; used as a prefix for the logged losses.

        Returns:
            total_loss (torch.FloatTensor), loss_dict (dict)
        """

        pred = model_outputs["pred"]

        if self.noise_scheduler.prediction_type == "epsilon":
            target = model_outputs["noise"]
        elif self.noise_scheduler.prediction_type == "sample":
            target = model_outputs["x_0"]
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")

        if self.loss_cfg.loss_type == "l1":
            simple = F.l1_loss(pred, target, reduction="mean")
        elif self.loss_cfg.loss_type in ["mse", "l2"]:
            simple = F.mse_loss(pred, target, reduction="mean")
        else:
            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")

        total_loss = simple

        loss_dict = {
            f"{split}/total_loss": total_loss.clone().detach(),
            f"{split}/simple": simple.detach(),
        }

        return total_loss, loss_dict

    def forward(self, batch):
        """
        One training forward pass: encode the surface to latents, add noise at a
        random timestep, and let the denoiser predict it.

        Args:
            batch (dict): must contain `self.first_stage_key` and `self.cond_stage_key`.

        Returns:
            diffusion_outputs (dict): with keys "x_0", "noise" and "pred".
        """

        latents = self.encode_first_stage(batch[self.first_stage_key])
        conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])

        # Sample noise that we'll add to the latents
        # [batch_size, n_token, latent_dim]
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # diffusion model forward
        noise_pred = self.model(noisy_z, timesteps, conditions)

        diffusion_outputs = {
            "x_0": latents,  # clean latents: the regression target when prediction_type == "sample"
            "noise": noise,
            "pred": noise_pred
        }

        return diffusion_outputs

    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, which contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], items in range [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [-1, 1]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    def validation_step(self, batch: Dict[str, torch.FloatTensor],
                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, which contains:
                - surface_pc (torch.FloatTensor): [n_pts, 4]
                - surface_feats (torch.FloatTensor): [n_pts, c]
                - text (list of str):

            batch_idx (int):

            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    @torch.no_grad()
    def sample(self,
               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False, **kwargs):

        if steps is None:
            steps = self.scheduler_cfg.num_inference_steps

        if guidance_scale is None:
            guidance_scale = self.scheduler_cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        # conditional encode
        xc = batch[self.cond_stage_key]

        # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)

        cond = self.cond_stage_model(xc)

        if do_classifier_free_guidance:
            un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
            cond = torch.cat([un_cond, cond], dim=0)

        outputs = []
        latents = None

        if not return_intermediates:
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.model,
                    shape=self.first_stage_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=not self.zero_rank
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.decode_first_stage(latents, **kwargs))
        else:
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.model,
                shape=self.first_stage_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=not self.zero_rank
            )

            iter_size = steps // sample_times
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.decode_first_stage(latents, **kwargs))
                i += 1

        return outputs
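
For readers unfamiliar with the diffusers API used in forward above: DDPMScheduler.add_noise implements the standard forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A minimal sketch (not part of the repo; the tensor shapes are placeholders) verifying this against the scheduler's alphas_cumprod table:

import torch
from diffusers.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
x0 = torch.randn(2, 4, 8)        # stand-in for the shape latents
eps = torch.randn_like(x0)
t = torch.tensor([10, 500])      # one timestep per batch element

x_t = scheduler.add_noise(x0, eps, t)

# closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
alpha_bar = scheduler.alphas_cumprod[t].view(-1, 1, 1)
expected = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
assert torch.allclose(x_t, expected, atol=1e-5)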
    	
primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py
ADDED
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler


__all__ = ["ddim_sample"]


def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):

    assert steps > 0, f"steps must be > 0, got {steps}."

    # init latents; with classifier-free guidance, cond holds the unconditional
    # and conditional embeddings stacked along the batch dimension
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma
    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler and must be in [0, 1]
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator
    }

    # reverse diffusion loop
    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )
        # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
            )
        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t


def karra_sample():
    # placeholder for a Karras-style sampler; not implemented yet
    pass
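
Since ddim_sample is a generator that yields (latents, t) at every step, the caller has to drain it and keep the last value, exactly as ClipASLDiffuser.sample does. A minimal sketch (not part of the repo) of driving it; the tiny zero-predicting denoiser, the latent shape, and the scheduler settings are illustrative assumptions, not the repo's actual configuration:

import torch
from diffusers.schedulers import DDIMScheduler

class DummyDenoiser(torch.nn.Module):
    def forward(self, x, t, cond):
        # a real denoiser would predict the noise from (x_t, t, cond)
        return torch.zeros_like(x)

scheduler = DDIMScheduler(num_train_timesteps=1000)
cond = torch.randn(2, 77, 768)  # uncond + cond embeddings stacked for CFG
latents = None
for latents, t in ddim_sample(
        scheduler, DummyDenoiser(),
        shape=(16, 64), cond=cond, steps=25,
        guidance_scale=3.0, do_classifier_free_guidance=True,
        device="cpu"):
    pass                          # keep only the final latents
print(latents.shape)              # torch.Size([1, 16, 64])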
    	
primitive_anything/michelangelo/models/conditional_encoders/__init__.py
ADDED
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

from .clip import CLIPEncoder
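
This initializer just re-exports the encoder, so downstream code can import it from the package (assuming the repo root is on PYTHONPATH):

from primitive_anything.michelangelo.models.conditional_encoders import CLIPEncoder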
    	
primitive_anything/michelangelo/models/conditional_encoders/clip.py
ADDED
@@ -0,0 +1,89 @@
| 1 | 
         
            +
            # -*- coding: utf-8 -*-
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            import numpy as np
         
     | 
| 5 | 
         
            +
            from PIL import Image
         
     | 
| 6 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 7 | 
         
            +
            from torchvision.transforms import Normalize
         
     | 
| 8 | 
         
            +
            from transformers import CLIPModel, CLIPTokenizer
         
     | 
| 9 | 
         
            +
            from transformers.utils import ModelOutput
         
     | 
| 10 | 
         
            +
            from typing import Iterable, Optional, Union, List
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            @dataclass
         
     | 
| 17 | 
         
            +
            class CLIPEmbedOutput(ModelOutput):
         
     | 
| 18 | 
         
            +
                last_hidden_state: torch.FloatTensor = None
         
     | 
| 19 | 
         
            +
                pooler_output: torch.FloatTensor = None
         
     | 
| 20 | 
         
            +
                embeds: torch.FloatTensor = None
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            class CLIPEncoder(torch.nn.Module):
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
                def __init__(self, model_path="openai/clip-vit-base-patch32"):
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                    super().__init__()
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
                    # Load the CLIP model and processor
         
     | 
| 30 | 
         
            +
                    self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
         
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # eval() clears the training flag recursively; assigning
        # self.model.training = False would only touch the root module.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def encode_image(self, images: Iterable[Optional[ImageType]]):
        # Normalize expects a float tensor batch (B, 3, H, W) scaled to [0, 1].
        pixel_values = self.image_preprocess(images)

        vision_outputs = self.model.vision_model(pixel_values=pixel_values)

        pooler_output = vision_outputs[1]  # pooled_output
        image_features = self.model.visual_projection(pooler_output)

        visual_embeds = CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features
        )

        return visual_embeds

    @torch.no_grad()
    def encode_text(self, texts: List[str]):
        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")

        # Pass the token ids, not the whole BatchEncoding the tokenizer returns.
        text_outputs = self.model.text_model(input_ids=text_inputs.input_ids)

        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)

        text_embeds = CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features
        )

        return text_embeds

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):

        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        return visual_embeds, text_embeds
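The encode paths above are thin wrappers over the Hugging Face CLIPModel. A standalone sketch of the text branch, using the same HF calls (the checkpoint name here is illustrative, not the one configured by this file):

```python
# Standalone sketch of the encode_text path above (checkpoint name illustrative).
import torch
from transformers import CLIPModel, CLIPTokenizer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

text_inputs = tokenizer(["a red chair"], padding=True, return_tensors="pt")
with torch.no_grad():
    text_outputs = model.text_model(input_ids=text_inputs.input_ids)
    embeds = model.text_projection(text_outputs[1])  # pooled_output -> joint space
print(embeds.shape)  # torch.Size([1, 512]) for the base checkpoint
```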
    	
primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py
ADDED
@@ -0,0 +1,562 @@
# -*- coding: utf-8 -*-
import os

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
from collections import OrderedDict

import clip  # OpenAI CLIP package; needed for clip.load() in MoECLIPImageEncoder below

from ...data.transforms import RandomResize
class AbstractEncoder(nn.Module):
    embedding_dim: int

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class"):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c
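A minimal sketch of what this lookup produces: each integer class id becomes a single embedding token, shaped (batch, 1, embed_dim) for cross-attention (all sizes illustrative; assumes the class above is in scope):

```python
# Sketch: ClassEmbedder turns integer class ids into one cross-attention
# token per sample (embed_dim/n_classes values illustrative).
import torch

embedder = ClassEmbedder(embed_dim=512, n_classes=1000, key="class")
batch = {"class": torch.tensor([3, 7])}   # two samples with class ids 3 and 7
c = embedder(batch)                       # batch["class"][:, None] -> embedding
print(c.shape)                            # torch.Size([2, 1, 512])
```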
class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        # The transformer lives in a plain OrderedDict rather than the module
        # tree, so it is excluded from checkpoints and recursive .to() calls.
        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        # Randomly blank out prompts (in place) with probability
        # zero_embedding_radio, classifier-free-guidance style.
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)
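The encode() path above implements condition dropping for classifier-free guidance: each prompt is blanked in place with probability zero_embedding_radio before encoding. A small sketch (shapes assume the default openai/clip-vit-large-patch14 checkpoint: 77 tokens, hidden size 768):

```python
# Sketch of FrozenCLIPTextEmbedder usage; device and shapes are illustrative.
embedder = FrozenCLIPTextEmbedder(device="cpu", zero_embedding_radio=0.1)

prompts = ["a red chair", "a wooden table"]
z = embedder.encode(prompts)   # note: may blank entries of `prompts` in place
print(z.shape)                 # torch.Size([2, 77, 768])

uncond = embedder.unconditional_embedding(batch_size=2)  # encodes all-"" prompts
```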
class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    # NOTE: currently an exact duplicate of FrozenCLIPTextEmbedder above.

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision encoder for images (from Hugging Face)"""

    def __init__(
            self,
            version="openai/clip-vit-large-patch14",
            device="cuda",
            zero_embedding_radio=0.1,
            normalize_embedding=True,
            num_projection_vector=0,
            linear_mapping_bias=True,
            reverse_visual_projection=False,
    ):
        super().__init__()

        self.device = device

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        # Freeze the CLIP weights themselves; self.parameters() would be empty
        # at this point because the model is kept out of the module tree.
        for param in clip_model.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio

        self.num_projection_vector = num_projection_vector
        self.reverse_visual_projection = reverse_visual_projection
        self.normalize_embedding = normalize_embedding

        embedding_dim = (
            clip_model.visual_projection.in_features
            if reverse_visual_projection
            else clip_model.visual_projection.out_features
        )
        self.embedding_dim = embedding_dim
        if self.num_projection_vector > 0:
            self.projection = nn.Linear(
                embedding_dim,
                clip_model.visual_projection.out_features * num_projection_vector,
                bias=linear_mapping_bias,
            )
            nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            1,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        if self.num_projection_vector > 0:
            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        if self.reverse_visual_projection:
            z = self.clip.vision_model(self.transform(image))[1]
        else:
            z = self.clip.get_image_features(self.transform(image))

        if self.normalize_embedding:
            z = z / z.norm(dim=-1, keepdim=True)
        if z.ndim == 2:
            z = z.unsqueeze(dim=-2)

        if zero_embedding_radio > 0:
            # Keep each embedding with probability 1 - zero_embedding_radio and
            # zero it otherwise (>= matches FrozenCLIPImageGridEmbedder below;
            # the original < comparison inverted the drop probability).
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        if self.num_projection_vector > 0:
            z = self.projection(z).view(len(image), self.num_projection_vector, -1)

        return z

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def encode(self, image):
        self.move()
        return self(image, zero_embedding_radio=self.zero_embedding_radio)
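Together, encode() and unconditional_embedding() give the conditional/unconditional pair needed for classifier-free guidance. A sketch (device and value ranges illustrative; embedding_dim is 768 for the default ViT-L/14 checkpoint):

```python
# Sketch: conditional vs. unconditional image embeddings for classifier-free
# guidance, using the class above.
import torch

embedder = FrozenCLIPImageEmbedder(device="cpu", zero_embedding_radio=0.1)
images = torch.rand(4, 3, 256, 256) * 2 - 1     # in (-1, 1); the transform resizes
cond = embedder.encode(images)                  # (4, 1, embedding_dim)
uncond = embedder.unconditional_embedding(batch_size=4)
# A sampler would then blend denoiser outputs, schematically:
#   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```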
class FrozenCLIPImageGridEmbedder(AbstractEncoder):

    def __init__(
            self,
            version="openai/clip-vit-large-patch14",
            device="cuda",
            zero_embedding_radio=0.1,
    ):
        super().__init__()

        self.device = device

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model: CLIPModel = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        # Freeze the CLIP weights themselves; self.parameters() would be empty
        # at this point because the model is kept out of the module tree.
        for param in clip_model.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio
        self.embedding_dim = clip_model.vision_embed_dim

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            self.clip.vision_model.embeddings.num_positions,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        self.move()

        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        z = self.clip.vision_model(self.transform(image)).last_hidden_state

        if zero_embedding_radio > 0:
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        return z

    def encode(self, image):
        return self(image, zero_embedding_radio=self.zero_embedding_radio)
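Unlike FrozenCLIPImageEmbedder, the grid variant keeps every vision-transformer token, which suits cross-attention over spatial features. A shape sketch for the default ViT-L/14 checkpoint at 224x224 input (16 x 16 patches plus the CLS token, hidden size 1024):

```python
# Sketch: the grid embedder returns the full token sequence instead of a
# single pooled vector (shapes assume openai/clip-vit-large-patch14).
import torch

embedder = FrozenCLIPImageGridEmbedder(device="cpu", zero_embedding_radio=0.1)
images = torch.rand(2, 3, 224, 224) * 2 - 1
z = embedder.encode(images)
print(z.shape)  # torch.Size([2, 257, 1024]) -> 256 patch tokens + 1 CLS token
```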
class MoECLIPImageEncoder(nn.Module):
    def __init__(
            self,
            versions,
            hidden_state_dim,
            num_projection_vector=8,
            zero_embedding_radio=0.1,
            device="cuda",
            precision="fp16",
            normalize=False,
            clip_max=0,
            transform_type="base",
            argument_p=0.2,  # probability of applying each random augmentation below
    ):
        super().__init__()

        self.device = torch.device(device)
        self.hidden_state_dim = hidden_state_dim
        self.zero_embedding_radio = zero_embedding_radio
        self.num_projection_vector = num_projection_vector
        self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
        self.normalize = normalize
        self.clip_max = clip_max

        if transform_type == "base":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        elif transform_type == "crop_blur_resize":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.RandomApply(
                        transforms=[
                            transforms.RandomResizedCrop(
                                size=224,
                                scale=(0.8, 1.0),
                                ratio=(0.99, 1.01),
                                interpolation=transforms.InterpolationMode.BICUBIC,
                            ),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            RandomResize(size=224, resize_radio=(0.2, 1)),
                        ],
                        p=argument_p,
                    ),
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        else:
            raise ValueError(f"invalid {transform_type=}")

        if isinstance(versions, str):
            versions = (versions,)

        # If the CLIP models were registered as submodules of this class:
        # 1. checkpoints would store several copies of useless weights, and
        # 2. PyTorch Lightning would call .to() on them, converting the
        #    layer-norm weights to fp16 as well.
        clips = OrderedDict()

        for v in versions:
            # Because the clips are not submodules, loading with device="cuda"
            # directly would wrongly place every CLIP model's weights on cuda:0.
            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
            delattr(clips[v], "transformer")
            clips[v].eval()
            clips[v].requires_grad_(False)

        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)

        if self.num_projection_vector == 0:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
            self.projection.to(dtype=self.dtype)
            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)

        self.clips = clips

        self._move_flag = False

    def move(self):
        if self._move_flag:
            return

        def convert_weights(model: nn.Module):
            """Convert applicable model parameters to fp16"""

            def _convert_weights_to_fp16(l):
                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                    l.weight.data = l.weight.data.type(self.dtype)
                    if l.bias is not None:
                        l.bias.data = l.bias.data.type(self.dtype)

                if isinstance(l, nn.MultiheadAttention):
                    for attr in [
                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                        "in_proj_bias",
                        "bias_k",
                        "bias_v",
                    ]:
                        tensor = getattr(l, attr)
                        if tensor is not None:
                            tensor.data = tensor.data.type(self.dtype)

                for name in ["text_projection", "proj"]:
                    if hasattr(l, name):
                        attr = getattr(l, name)
                        if attr is not None:
                            attr.data = attr.data.type(self.dtype)

            model.apply(_convert_weights_to_fp16)

        for k in self.clips:
            self.clips[k].to(self.device)
            convert_weights(self.clips[k])  # fp32 -> self.dtype
        self._move_flag = True

| 514 | 
         
            +
                def unconditional_embedding(self, batch_size=None):
         
     | 
| 515 | 
         
            +
                    zero = torch.zeros(
         
     | 
| 516 | 
         
            +
                        batch_size,
         
     | 
| 517 | 
         
            +
                        self.clips_hidden_dim,
         
     | 
| 518 | 
         
            +
                        device=self.device,
         
     | 
| 519 | 
         
            +
                        dtype=self.dtype,
         
     | 
| 520 | 
         
            +
                    )
         
     | 
| 521 | 
         
            +
                    if self.num_projection_vector > 0:
         
     | 
| 522 | 
         
            +
                        zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
         
     | 
| 523 | 
         
            +
                    return zero
         
     | 
| 524 | 
         
            +
             
     | 
| 525 | 
         
            +
                def convert_embedding(self, z):
         
     | 
| 526 | 
         
            +
                    if self.num_projection_vector > 0:
         
     | 
| 527 | 
         
            +
                        z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
         
     | 
| 528 | 
         
            +
                    return z
         
     | 
| 529 | 
         
            +
             
     | 
| 530 | 
         
            +
                def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
         
     | 
| 531 | 
         
            +
                    if value_range is not None:
         
     | 
| 532 | 
         
            +
                        low, high = value_range
         
     | 
| 533 | 
         
            +
                        image = (image - low) / (high - low)
         
     | 
| 534 | 
         
            +
             
     | 
| 535 | 
         
            +
                    image = self.transform(image)
         
     | 
| 536 | 
         
            +
             
     | 
| 537 | 
         
            +
                    with torch.no_grad():
         
     | 
| 538 | 
         
            +
                        embs = []
         
     | 
| 539 | 
         
            +
                        for v in self.clips:
         
     | 
| 540 | 
         
            +
                            x = self.clips[v].encode_image(image)
         
     | 
| 541 | 
         
            +
                            if self.normalize:
         
     | 
| 542 | 
         
            +
                                x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
         
     | 
| 543 | 
         
            +
                                # clip_max only works with normalization
         
     | 
| 544 | 
         
            +
                                if self.clip_max > 0:
         
     | 
| 545 | 
         
            +
                                    x = x.clamp(-self.clip_max, self.clip_max)
         
     | 
| 546 | 
         
            +
                            embs.append(x)
         
     | 
| 547 | 
         
            +
             
     | 
| 548 | 
         
            +
                        z = torch.cat(embs, dim=-1)
         
     | 
| 549 | 
         
            +
                        if self.normalize:
         
     | 
| 550 | 
         
            +
                            z /= z.size(-1) ** 0.5
         
     | 
| 551 | 
         
            +
             
     | 
| 552 | 
         
            +
                    if zero_embedding_radio > 0:
         
     | 
| 553 | 
         
            +
                        mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
         
     | 
| 554 | 
         
            +
                        z = z + mask.to(z)
         
     | 
| 555 | 
         
            +
             
     | 
| 556 | 
         
            +
                    if self.num_projection_vector > 0:
         
     | 
| 557 | 
         
            +
                        z = self.projection(z).view(len(image), self.num_projection_vector, -1)
         
     | 
| 558 | 
         
            +
                    return z
         
     | 
| 559 | 
         
            +
             
     | 
| 560 | 
         
            +
                def encode(self, image):
         
     | 
| 561 | 
         
            +
                    self.move()
         
     | 
| 562 | 
         
            +
                    return self(image, zero_embedding_radio=self.zero_embedding_radio)
         
     | 
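For orientation: the `zero_embedding_radio` branch in `forward` implements classifier-free-guidance dropout, zeroing each sample's embedding with the given probability so the model also learns an unconditional pathway. A minimal standalone sketch of that mechanic (the helper name and shapes are illustrative, not part of this diff):

import torch

def cfg_dropout(z: torch.Tensor, zero_embedding_radio: float) -> torch.Tensor:
    # Zero each row of z with probability zero_embedding_radio (illustrative helper).
    if zero_embedding_radio <= 0:
        return z
    keep = torch.rand(len(z), 1, device=z.device, dtype=z.dtype) >= zero_embedding_radio
    return z * keep.to(z)

z = torch.randn(4, 768)
print(cfg_dropout(z, 0.5))  # on average, half of the rows come out all-zero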
    	
primitive_anything/michelangelo/models/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

from .checkpoint import checkpoint
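With this re-export in place, downstream code can import the helper from the package itself rather than from the submodule:

# After this __init__.py, the helper is importable from the package root:
from primitive_anything.michelangelo.models.modules import checkpoint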
    	
primitive_anything/michelangelo/models/modules/checkpoint.py
ADDED
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""

import torch
from typing import Callable, Iterable, Sequence, Union


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :param use_deepspeed: if True, use deepspeed
    """
    if flag:
        if use_deepspeed:
            import deepspeed
            return deepspeed.checkpointing.checkpoint(func, *inputs)

        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
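A hedged usage sketch of the `checkpoint` helper above (the module and shapes are invented for the example): passing `module.parameters()` as `params` is what lets gradients reach weights that `func` closes over rather than receives as explicit arguments.

import torch
import torch.nn as nn

mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# Activations inside mlp are recomputed during backward instead of being stored.
y = checkpoint(lambda t: mlp(t), (x,), mlp.parameters(), flag=True)
y.sum().backward()
assert x.grad is not None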
    	
primitive_anything/michelangelo/models/modules/diffusion_transformer.py
ADDED
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-

import math
import torch
import torch.nn as nn
from typing import Optional

from .checkpoint import checkpoint
from .transformer_blocks import (
    init_linear,
    MLP,
    MultiheadCrossAttention,
    MultiheadAttention,
    ResidualAttentionBlock
)


class AdaLayerNorm(nn.Module):
    def __init__(self,
                 device: torch.device,
                 dtype: torch.dtype,
                 width: int):

        super().__init__()

        self.silu = nn.SiLU(inplace=True)
        self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
        self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)

    def forward(self, x, timestep):
        # Apply the SiLU nonlinearity to the timestep embedding before the
        # projection (otherwise self.silu would be defined but never used).
        emb = self.linear(self.silu(timestep))
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x


class DitBlock(nn.Module):
    def __init__(
            self,
            *,
            device: torch.device,
            dtype: torch.dtype,
            n_ctx: int,
            width: int,
            heads: int,
            context_dim: int,
            qkv_bias: bool = False,
            init_scale: float = 1.0,
            use_checkpoint: bool = False
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias
        )
        self.ln_1 = AdaLayerNorm(device, dtype, width)

        if context_dim is not None:
            self.ln_2 = AdaLayerNorm(device, dtype, width)
            self.cross_attn = MultiheadCrossAttention(
                device=device,
                dtype=dtype,
                width=width,
                heads=heads,
                data_width=context_dim,
                init_scale=init_scale,
                qkv_bias=qkv_bias
            )

        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_3 = AdaLayerNorm(device, dtype, width)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)

    def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        x = x + self.attn(self.ln_1(x, t))
        if context is not None:
            x = x + self.cross_attn(self.ln_2(x, t), context)
        x = x + self.mlp(self.ln_3(x, t))
        return x


class DiT(nn.Module):
    def __init__(
            self,
            *,
            device: Optional[torch.device],
            dtype: Optional[torch.dtype],
            n_ctx: int,
            width: int,
            layers: int,
            heads: int,
            context_dim: int,
            init_scale: float = 0.25,
            qkv_bias: bool = False,
            use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList(
            [
                DitBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    context_dim=context_dim,
                    qkv_bias=qkv_bias,
                    init_scale=init_scale,
                    use_checkpoint=use_checkpoint
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        for block in self.resblocks:
            x = block(x, t, context)
        return x


class UNetDiffusionTransformer(nn.Module):
    def __init__(
            self,
            *,
            device: Optional[torch.device],
            dtype: Optional[torch.dtype],
            n_ctx: int,
            width: int,
            layers: int,
            heads: int,
            init_scale: float = 0.25,
            qkv_bias: bool = False,
            skip_ln: bool = False,
            use_checkpoint: bool = False
    ):
        super().__init__()

        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        self.encoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                device=device,
                dtype=dtype,
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint
            )
            self.encoder.append(resblock)

        self.middle_block = ResidualAttentionBlock(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_checkpoint=use_checkpoint
        )

        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                device=device,
                dtype=dtype,
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint
            )
            linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
            init_linear(linear, init_scale)

            layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None

            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):

        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)

        x = self.middle_block(x)

        for resblock, linear, layer_norm in self.decoder:
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)

            if layer_norm is not None:
                x = layer_norm(x)

            x = resblock(x)

        return x
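To make the shapes concrete, a hedged instantiation sketch of `UNetDiffusionTransformer` above (hyperparameter values are invented for the example, and it assumes the sibling `transformer_blocks` module from this repo is importable so the residual blocks can be constructed):

import torch

model = UNetDiffusionTransformer(
    device=torch.device("cpu"),
    dtype=torch.float32,
    n_ctx=256,   # tokens per sample
    width=512,   # token feature dimension
    layers=4,    # 4 encoder blocks + 1 middle block + 4 decoder blocks
    heads=8,
    skip_ln=True,
)
x = torch.randn(2, 256, 512)  # (batch, n_ctx, width)
y = model(x)                  # U-Net skips: each decoder block first concatenates
assert y.shape == x.shape     # an encoder output, then projects back to width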