Commit ad06aed · 1 Parent(s): 6af576b
tokenid committed: init

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +1 -0
 - README.md +1 -1
 - app.py +349 -0
 - configs/instant-mesh-base.yaml +22 -0
 - configs/instant-mesh-large.yaml +22 -0
 - configs/instant-nerf-base.yaml +21 -0
 - configs/instant-nerf-large.yaml +21 -0
 - examples/bird.jpg +0 -0
 - examples/bubble_mart_blue.png +0 -0
 - examples/cake.jpg +0 -0
 - examples/cartoon_dinosaur.png +0 -0
 - examples/cartoon_girl.jpg +0 -0
 - examples/chair_comfort.jpg +0 -0
 - examples/chair_wood.jpg +0 -0
 - examples/chest.jpg +0 -0
 - examples/cube.png +0 -0
 - examples/extinguisher.png +0 -0
 - examples/fruit_bycycle.jpg +0 -0
 - examples/fruit_elephant.jpg +0 -0
 - examples/genshin_building.png +0 -0
 - examples/house2.jpg +0 -0
 - examples/kunkun.png +0 -0
 - examples/mushroom_teapot.jpg +0 -0
 - examples/pikachu.png +0 -0
 - examples/pistol.png +0 -0
 - examples/plant.jpg +0 -0
 - examples/robot.jpg +0 -0
 - examples/sea_turtle.png +0 -0
 - examples/skating_shoe.jpg +0 -0
 - examples/sorting_board.png +0 -0
 - examples/sword.png +0 -0
 - examples/toy_car.jpg +0 -0
 - examples/toyduck.png +0 -0
 - examples/watermelon.png +0 -0
 - examples/whitedog.png +0 -0
 - examples/x_cube.jpg +0 -0
 - examples/x_teapot.jpg +0 -0
 - examples/x_toyduck.jpg +0 -0
 - requirements.txt +21 -0
 - src/__init__.py +0 -0
 - src/data/__init__.py +0 -0
 - src/data/objaverse.py +329 -0
 - src/model.py +310 -0
 - src/model_mesh.py +325 -0
 - src/models/__init__.py +0 -0
 - src/models/decoder/__init__.py +0 -0
 - src/models/decoder/transformer.py +123 -0
 - src/models/encoder/__init__.py +0 -0
 - src/models/encoder/dino.py +550 -0
 - src/models/encoder/dino_wrapper.py +80 -0
 
    	
.gitignore ADDED

@@ -0,0 +1 @@
+__pycache__
    	
README.md CHANGED

@@ -7,7 +7,7 @@ sdk: gradio
 sdk_version: 4.25.0
 app_file: app.py
 pinned: false
-license:
+license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py ADDED

@@ -0,0 +1,349 @@
+import os
+import imageio
+import numpy as np
+import torch
+import rembg
+from PIL import Image
+from torchvision.transforms import v2
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from tqdm import tqdm
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics, 
+    get_zero123plus_input_cameras,
+    get_circular_camera_poses,
+)
+from src.utils.mesh_util import save_obj
+from src.utils.infer_util import remove_background, resize_foreground, images_to_video
+
+import tempfile
+from functools import partial
+
+from huggingface_hub import hf_hub_download
+import spaces
+
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    """
+    Get the rendering camera parameters.
+    """
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+def images_to_video(images, output_path, fps=30):
+    # images: (N, C, H, W)
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    frames = []
+    for i in range(images.shape[0]):
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        assert frame.min() >= 0 and frame.max() <= 255, \
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
+
+
+###############################################################################
+# Configuration.
+###############################################################################
+
+config_path = 'configs/instant-mesh-large-eval.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
+
+device = torch.device('cuda')
+
+# load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2", 
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+# load custom white-background UNet
+unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+
+# load reconstruction model
+print('Loading reconstruction model ...')
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+model = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+if IS_FLEXICUBES:
+    model.init_flexicubes_geometry(device)
+model = model.eval()
+
+print('Loading Finished!')
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background):
+
+    rembg_session = rembg.new_session() if do_remove_background else None
+
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+
+    return input_image
+
+
+@spaces.GPU
+def generate_mvs(input_image, sample_steps, sample_seed):
+
+    seed_everything(sample_seed)
+
+    # sampling
+    z123_image = pipeline(
+        input_image, 
+        num_inference_steps=sample_steps
+    ).images[0]
+
+    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = torch.from_numpy(show_image)     # (960, 640, 3)
+    show_image = rearrange(show_image, '(n h) (m w) c -> (m h) (n w) c', n=3, m=2)
+    show_image = Image.fromarray(show_image.numpy())
+
+    return z123_image, show_image
+
+
+@spaces.GPU
+def make_mesh(mesh_fpath, planes):
+
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    mesh_vis_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
+
+    with torch.no_grad():
+
+        # get mesh
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            **infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+        vertices = vertices[:, [0, 2, 1]]
+        vertices[:, -1] *= -1
+
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        print(f"Mesh saved to {mesh_fpath}")
+
+    return mesh_fpath
+
+
+@spaces.GPU
+def make3d(images):
+
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320)
+
+    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device)
+    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        # get video
+        chunk_size = 20 if IS_FLEXICUBES else 1
+        render_size = 384
+
+        frames = []
+        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
+            if IS_FLEXICUBES:
+                frame = model.forward_geometry(
+                    planes,
+                    render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['img']
+            else:
+                frame = model.synthesizer(
+                    planes,
+                    cameras=render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['images_rgb']
+            frames.append(frame)
+        frames = torch.cat(frames, dim=1)
+
+        images_to_video(
+            frames[0],
+            video_fpath,
+            fps=30,
+        )
+
+        print(f"Video saved to {video_fpath}")
+
+    mesh_fpath = make_mesh(mesh_fpath, planes)
+
+    return video_fpath, mesh_fpath
+
+
+import gradio as gr
+
+_HEADER_ = '''
+<h2><b>Official 🤗 Gradio demo for</b>
+<a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
+<b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
+</a>.
+</h2>
+'''
+
+_LINKS_ = '''
+<h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
+<h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
+'''
+
+_CITE_ = r"""
+```bibtex
+@article{xu2024instantmesh,
+  title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
+  author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
+  journal={arXiv preprint arXiv:2404.07191},
+  year={2024}
+}
+```
+"""
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(_HEADER_)
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    width=256,
+                    height=256,
+                    type="pil",
+                    elem_id="content_image",
+                )
+                processed_image = gr.Image(
+                    label="Processed Image", 
+                    image_mode="RGBA", 
+                    width=256,
+                    height=256,
+                    type="pil", 
+                    interactive=False
+                )
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(
+                        label="Remove Background", value=True
+                    )
+                    sample_seed = gr.Number(value=42, label="Seed  (Try a different value if the result is unsatisfying)", precision=0)
+
+                    sample_steps = gr.Slider(
+                        label="Sample Steps",
+                        minimum=30,
+                        maximum=75,
+                        value=75,
+                        step=5
+                    )
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=[input_image],
+                    label="Examples",
+                    examples_per_page=20
+                )
+
+        with gr.Column():
+
+            with gr.Row():
+
+                with gr.Column():
+                    mv_show_images = gr.Image(
+                        label="Generated Multi-views",
+                        type="pil",
+                        width=379,
+                        interactive=False
+                    )
+
+                with gr.Column():
+                    output_video = gr.Video(
+                        label="video", format="mp4",
+                        width=379,
+                        autoplay=True,
+                        interactive=False
+                    )
+
+            with gr.Row():
+                output_model_obj = gr.Model3D(
+                    label="Output Model (OBJ Format)",
+                    width=768,
+                    interactive=False,
+                )
+    gr.Markdown(_LINKS_)
+    gr.Markdown(_CITE_)
+
+    mv_images = gr.State()
+
+    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess,
+        inputs=[input_image, do_remove_background],
+        outputs=[processed_image],
+    ).success(
+        fn=generate_mvs,
+        inputs=[processed_image, sample_steps, sample_seed],
+        outputs=[mv_images, mv_show_images],
+    ).success(
+        fn=make3d,
+        inputs=[mv_images],
+        outputs=[output_video, output_model_obj]
+    )

+demo.launch()
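The UI above chains the callbacks with submit.click(check_input_image).success(preprocess).success(generate_mvs).success(make3d). The same flow can be driven without the browser; the sketch below is not part of this commit and assumes the functions in app.py can be imported with the module-level model setup run but without demo.launch() blocking:

```python
# Minimal headless sketch (assumptions: CUDA available, the diffusion pipeline and
# reconstruction model from app.py's module-level setup are loaded, and demo.launch()
# is guarded so that importing the module does not block).
from PIL import Image
import app  # hypothetical import of this commit's app.py as a module

img = Image.open("examples/bird.jpg").convert("RGBA")            # example image added in this commit
img = app.preprocess(img, do_remove_background=True)             # rembg + foreground resize
z123_image, grid = app.generate_mvs(img, sample_steps=75, sample_seed=42)  # 960x640 sheet of 6 views
video_path, mesh_path = app.make3d(z123_image)                   # triplane -> turntable video + OBJ mesh
print(video_path, mesh_path)
```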
    	
configs/instant-mesh-base.yaml ADDED

@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_base.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
    	
configs/instant-mesh-large.yaml ADDED

@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_large.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
    	
configs/instant-nerf-base.yaml ADDED

@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_base.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 384
    	
configs/instant-nerf-large.yaml ADDED

@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_large.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 384
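All four configs use the same target/params layout, which app.py feeds to src.utils.train_util.instantiate_from_config (imported but not included in this diff). That helper commonly follows the pattern sketched below; this is an assumption for illustration, not the file from this repository:

```python
# Hedged sketch of a typical instantiate_from_config helper (assumption, not from this commit).
import importlib

def get_obj_from_str(path: str):
    # "src.models.lrm_mesh.InstantMesh" -> the InstantMesh class object
    module_name, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)

def instantiate_from_config(config):
    # config is e.g. OmegaConf.load("configs/instant-mesh-large.yaml").model_config
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
```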
    	
The following example images were added (binary files, no text diff):

examples/bird.jpg ADDED
examples/bubble_mart_blue.png ADDED
examples/cake.jpg ADDED
examples/cartoon_dinosaur.png ADDED
examples/cartoon_girl.jpg ADDED
examples/chair_comfort.jpg ADDED
examples/chair_wood.jpg ADDED
examples/chest.jpg ADDED
examples/cube.png ADDED
examples/extinguisher.png ADDED
examples/fruit_bycycle.jpg ADDED
examples/fruit_elephant.jpg ADDED
examples/genshin_building.png ADDED
examples/house2.jpg ADDED
examples/kunkun.png ADDED
examples/mushroom_teapot.jpg ADDED
examples/pikachu.png ADDED
examples/pistol.png ADDED
examples/plant.jpg ADDED
examples/robot.jpg ADDED
examples/sea_turtle.png ADDED
examples/skating_shoe.jpg ADDED
examples/sorting_board.png ADDED
examples/sword.png ADDED
examples/toy_car.jpg ADDED
examples/toyduck.png ADDED
examples/watermelon.png ADDED
examples/whitedog.png ADDED
examples/x_cube.jpg ADDED
examples/x_teapot.jpg ADDED
examples/x_toyduck.jpg ADDED
    	
requirements.txt ADDED

@@ -0,0 +1,21 @@
+pytorch-lightning==2.1.2
+einops
+omegaconf
+deepspeed
+torchmetrics
+webdataset
+accelerate
+tensorboard
+PyMCubes
+trimesh
+rembg
+transformers==4.34.1
+diffusers==0.19.3
+bitsandbytes
+imageio[ffmpeg]
+xatlas
+plyfile
+xformers==0.0.22.post7
+git+https://github.com/NVlabs/nvdiffrast/
+torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
+huggingface-hub
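Note that torch itself is not pinned here; the torch-scatter wheel index (torch-2.1.0+cu121) implies a PyTorch 2.1 / CUDA 12.1 environment. A quick sanity check of an installed environment, as an illustrative sketch rather than anything shipped in this commit:

```python
# Environment sanity check (assumption: the stack above is installed).
import torch, diffusers, transformers, pytorch_lightning as pl

print(torch.__version__, torch.version.cuda)  # expect a 2.1.x build with CUDA 12.1 per the wheel URL
print(diffusers.__version__, transformers.__version__, pl.__version__)  # 0.19.3 / 4.34.1 / 2.1.2
```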
    	
src/__init__.py ADDED

File without changes

src/data/__init__.py ADDED

File without changes
    	
src/data/objaverse.py ADDED
@@ -0,0 +1,329 @@
+import os, sys
+import math
+import json
+import importlib
+from pathlib import Path
+
+import cv2
+import random
+import numpy as np
+from PIL import Image
+import webdataset as wds
+import pytorch_lightning as pl
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torchvision import transforms
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    center_looking_at_camera_pose,
+    get_surrounding_views,
+)
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size=8,
+        num_workers=4,
+        train=None,
+        validation=None,
+        test=None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.dataset_configs = dict()
+        if train is not None:
+            self.dataset_configs['train'] = train
+        if validation is not None:
+            self.dataset_configs['validation'] = validation
+        if test is not None:
+            self.dataset_configs['test'] = test
+
+    def setup(self, stage):
+
+        if stage in ['fit']:
+            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        else:
+            raise NotImplementedError
+
+    def train_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['train'])
+        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def val_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['validation'])
+        return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def test_dataloader(self):
+
+        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        meta_fname='valid_paths.json',
+        input_image_dir='rendering_random_32views',
+        target_image_dir='rendering_random_32views',
+        input_view_num=6,
+        target_view_num=2,
+        total_view_n=32,
+        fov=50,
+        camera_rotation=True,
+        validation=False,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_image_dir = input_image_dir
+        self.target_image_dir = target_image_dir
+
+        self.input_view_num = input_view_num
+        self.target_view_num = target_view_num
+        self.total_view_n = total_view_n
+        self.fov = fov
+        self.camera_rotation = camera_rotation
+
+        with open(os.path.join(root_dir, meta_fname)) as f:
+            filtered_dict = json.load(f)
+        paths = filtered_dict['good_objs']
+        self.paths = paths
+
+        self.depth_scale = 4.0
+
+        total_objects = len(self.paths)
+        print('============= length of dataset %d =============' % len(self.paths))
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        '''
+        replace background pixel with random color in rendering
+        '''
+        pil_img = Image.open(path)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        alpha = image[:, :, 3:]
+        image = image[:, :, :3] * alpha + color * (1 - alpha)
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        # load data
+        while True:
+            input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
+            target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
+
+            indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
+            input_indices = indices[:self.input_view_num]
+            target_indices = indices[self.input_view_num:]
+
+            '''background color, default: white'''
+            bg_white = [1., 1., 1.]
+            bg_black = [0., 0., 0.]
+
+            image_list = []
+            alpha_list = []
+            depth_list = []
+            normal_list = []
+            pose_list = []
+
+            try:
+                input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
+                for idx in input_indices:
+                    image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = input_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+                target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
+                for idx in target_indices:
+                    image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = target_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+            except Exception as e:
+                print(e)
+                index = np.random.randint(0, len(self.paths))
+                continue
+
+            break
+
+        images = torch.stack(image_list, dim=0).float()                 # (6+V, 3, H, W)
+        alphas = torch.stack(alpha_list, dim=0).float()                 # (6+V, 1, H, W)
+        depths = torch.stack(depth_list, dim=0).float()                 # (6+V, 1, H, W)
+        normals = torch.stack(normal_list, dim=0).float()               # (6+V, 3, H, W)
+        w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float()    # (6+V, 4, 4)
+        c2ws = torch.linalg.inv(w2cs).float()
+
+        normals = normals * 2.0 - 1.0
+        normals = F.normalize(normals, dim=1)
+        normals = (normals + 1.0) / 2.0
+        normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random rotation along z axis
+        if self.camera_rotation:
+            degree = np.random.uniform(0, math.pi * 2)
+            rot = torch.tensor([
+                [np.cos(degree), -np.sin(degree), 0, 0],
+                [np.sin(degree), np.cos(degree), 0, 0],
+                [0, 0, 1, 0],
+                [0, 0, 0, 1],
+            ]).unsqueeze(0).float()
+            c2ws = torch.matmul(rot, c2ws)
+
+            # rotate normals
+            N, _, H, W = normals.shape
+            normals = normals * 2.0 - 1.0
+            normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
+            normals = F.normalize(normals, dim=1)
+            normals = (normals + 1.0) / 2.0
+            normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random scaling
+        if np.random.rand() < 0.5:
+            scale = np.random.uniform(0.8, 1.0)
+            c2ws[:, :3, 3] *= scale
+            depths *= scale
+
+        # instrinsics of perspective cameras
+        K = FOV_to_intrinsics(self.fov)
+        Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
+
+        data = {
+            'input_images': images[:self.input_view_num],           # (6, 3, H, W)
+            'input_alphas': alphas[:self.input_view_num],           # (6, 1, H, W)
+            'input_depths': depths[:self.input_view_num],           # (6, 1, H, W)
+            'input_normals': normals[:self.input_view_num],         # (6, 3, H, W)
+            'input_c2ws': c2ws[:self.input_view_num],               # (6, 4, 4)
+            'input_Ks': Ks[:self.input_view_num],                   # (6, 3, 3)
+
+            # lrm generator input and supervision
+            'target_images': images[self.input_view_num:],          # (V, 3, H, W)
+            'target_alphas': alphas[self.input_view_num:],          # (V, 1, H, W)
+            'target_depths': depths[self.input_view_num:],          # (V, 1, H, W)
+            'target_normals': normals[self.input_view_num:],        # (V, 3, H, W)
+            'target_c2ws': c2ws[self.input_view_num:],              # (V, 4, 4)
+            'target_Ks': Ks[self.input_view_num:],                  # (V, 3, 3)
+
+            'depth_available': 1,
+        }
+        return data
+
+
+class ValidationData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        input_view_num=6,
+        input_image_size=256,
+        fov=50,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_view_num = input_view_num
+        self.input_image_size = input_image_size
+        self.fov = fov
+
+        self.paths = sorted(os.listdir(self.root_dir))
+        print('============= length of dataset %d =============' % len(self.paths))
+
+        cam_distance = 2.5
+        azimuths = np.array([30, 90, 150, 210, 270, 330])
+        elevations = np.array([30, -20, 30, -20, 30, -20])
+        azimuths = np.deg2rad(azimuths)
+        elevations = np.deg2rad(elevations)
+
+        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
+        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
+        z = cam_distance * np.sin(elevations)
+
+        cam_locations = np.stack([x, y, z], axis=-1)
+        cam_locations = torch.from_numpy(cam_locations).float()
+        c2ws = center_looking_at_camera_pose(cam_locations)
+        self.c2ws = c2ws.float()
+        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
+
+        render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
+        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
+        self.render_c2ws = render_c2ws.float()
+        self.render_Ks = render_Ks.float()
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        '''
+        replace background pixel with random color in rendering
+        '''
+        pil_img = Image.open(path)
+        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        if image.shape[-1] == 4:
+            alpha = image[:, :, 3:]
+            image = image[:, :, :3] * alpha + color * (1 - alpha)
+        else:
+            alpha = np.ones_like(image[:, :, :1])
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        # load data
+        input_image_path = os.path.join(self.root_dir, self.paths[index])
+
+        '''background color, default: white'''
+        # color = np.random.uniform(0.48, 0.52)
+        bkg_color = [1.0, 1.0, 1.0]
+
+        image_list = []
+        alpha_list = []
+
+        for idx in range(self.input_view_num):
+            image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
+            image_list.append(image)
+            alpha_list.append(alpha)
+
+        images = torch.stack(image_list, dim=0).float()                 # (6+V, 3, H, W)
+        alphas = torch.stack(alpha_list, dim=0).float()                 # (6+V, 1, H, W)
+
+        data = {
+            'input_images': images,             # (6, 3, H, W)
+            'input_alphas': alphas,             # (6, 1, H, W)
+            'input_c2ws': self.c2ws,            # (6, 4, 4)
+            'input_Ks': self.Ks,                # (6, 3, 3)
+
+            'render_c2ws': self.render_c2ws,
+            'render_Ks': self.render_Ks,
+        }
+        return data
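ObjaverseData.load_im composites each RGBA render onto the supplied background color through its alpha channel and returns channel-first tensors; in __getitem__ the color renders are composited onto white and the normal maps onto black. A self-contained sketch of that compositing on a synthetic 2x2 RGBA image (illustrative only, not part of the commit; the pixel values are made up):

# Sketch: the alpha compositing performed by ObjaverseData.load_im,
# applied to a tiny synthetic RGBA array instead of a rendered PNG.
import numpy as np
import torch

rgba = np.zeros((2, 2, 4), dtype=np.float32)
rgba[0, 0] = [1.0, 0.0, 0.0, 1.0]   # opaque red pixel
rgba[1, 1] = [0.0, 1.0, 0.0, 0.5]   # half-transparent green pixel; the rest is fully transparent

color = np.array([1.0, 1.0, 1.0], dtype=np.float32)     # bg_white, as in __getitem__

alpha = rgba[:, :, 3:]                                   # (H, W, 1)
image = rgba[:, :, :3] * alpha + color * (1 - alpha)     # blend toward the background color

# channel-first tensors, matching what load_im returns
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()  # (3, H, W)
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()  # (1, H, W)

print(image[:, 1, 1])   # tensor([0.5000, 1.0000, 0.5000]): half green, half white
print(image[:, 0, 1])   # tensor([1., 1., 1.]): fully transparent pixel becomes the background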
    	
src/model.py ADDED
@@ -0,0 +1,310 @@
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision.transforms import v2
+from torchvision.utils import make_grid, save_image
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+import pytorch_lightning as pl
+from einops import rearrange, repeat
+
+from src.utils.train_util import instantiate_from_config
+
+
+class MVRecon(pl.LightningModule):
+    def __init__(
+        self,
+        lrm_generator_config,
+        lrm_path=None,
+        input_size=256,
+        render_size=192,
+    ):
+        super(MVRecon, self).__init__()
+
+        self.input_size = input_size
+        self.render_size = render_size
+
+        # init modules
+        self.lrm_generator = instantiate_from_config(lrm_generator_config)
+        if lrm_path is not None:
+            lrm_ckpt = torch.load(lrm_path)
+            self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
+
+        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
+
+        self.validation_step_outputs = []
+
+    def on_fit_start(self):
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
+
+    def prepare_batch_data(self, batch):
+        lrm_generator_input = {}
+        render_gt = {}   # for supervision
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        # input cameras and render cameras
+        input_c2ws = batch['input_c2ws'].flatten(-2)
+        input_Ks = batch['input_Ks'].flatten(-2)
+        target_c2ws = batch['target_c2ws'].flatten(-2)
+        target_Ks = batch['target_Ks'].flatten(-2)
+        render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
+        render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
+        render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
+
+        input_extrinsics = input_c2ws[:, :, :12]
+        input_intrinsics = torch.stack([
+            input_Ks[:, :, 0], input_Ks[:, :, 4],
+            input_Ks[:, :, 2], input_Ks[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        # add noise to input cameras
+        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
+
+        # target images
+        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
+        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
+        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
+
+        # random crop
+        render_size = np.random.randint(self.render_size, 513)
+        target_images = v2.functional.resize(
+            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
+        target_depths = v2.functional.resize(
+            target_depths, render_size, interpolation=0, antialias=True)
+        target_alphas = v2.functional.resize(
+            target_alphas, render_size, interpolation=0, antialias=True)
+
+        crop_params = v2.RandomCrop.get_params(
+            target_images, output_size=(self.render_size, self.render_size))
+        target_images = v2.functional.crop(target_images, *crop_params)
+        target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
+        target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
+
+        lrm_generator_input['render_size'] = render_size
+        lrm_generator_input['crop_params'] = crop_params
+
+        render_gt['target_images'] = target_images.to(self.device)
+        render_gt['target_depths'] = target_depths.to(self.device)
+        render_gt['target_alphas'] = target_alphas.to(self.device)
+
+        return lrm_generator_input, render_gt
+
+    def prepare_validation_batch_data(self, batch):
+        lrm_generator_input = {}
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        input_c2ws = batch['input_c2ws'].flatten(-2)
+        input_Ks = batch['input_Ks'].flatten(-2)
+
+        input_extrinsics = input_c2ws[:, :, :12]
+        input_intrinsics = torch.stack([
+            input_Ks[:, :, 0], input_Ks[:, :, 4],
+            input_Ks[:, :, 2], input_Ks[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+
+        render_c2ws = batch['render_c2ws'].flatten(-2)
+        render_Ks = batch['render_Ks'].flatten(-2)
+        render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
+
+        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
+        lrm_generator_input['render_size'] = 384
+        lrm_generator_input['crop_params'] = None
+
+        return lrm_generator_input
+
+    def forward_lrm_generator(
+        self,
+        images,
+        cameras,
+        render_cameras,
+        render_size=192,
+        crop_params=None,
+        chunk_size=1,
+    ):
+        planes = torch.utils.checkpoint.checkpoint(
+            self.lrm_generator.forward_planes,
+            images,
+            cameras,
+            use_reentrant=False,
+        )
+        frames = []
+        for i in range(0, render_cameras.shape[1], chunk_size):
+            frames.append(
+                torch.utils.checkpoint.checkpoint(
+                    self.lrm_generator.synthesizer,
+                    planes,
+                    cameras=render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                    crop_params=crop_params,
+                    use_reentrant=False
+                )
+            )
+        frames = {
+            k: torch.cat([r[k] for r in frames], dim=1)
+            for k in frames[0].keys()
+        }
         
     | 
| 167 | 
         
            +
                    return frames
         
     | 
| 168 | 
         
            +
                
         
     | 
| 169 | 
         
            +
                def forward(self, lrm_generator_input):
         
     | 
| 170 | 
         
            +
                    images = lrm_generator_input['images']
         
     | 
| 171 | 
         
            +
                    cameras = lrm_generator_input['cameras']
         
     | 
| 172 | 
         
            +
                    render_cameras = lrm_generator_input['render_cameras']
         
     | 
| 173 | 
         
            +
                    render_size = lrm_generator_input['render_size']
         
     | 
| 174 | 
         
            +
                    crop_params = lrm_generator_input['crop_params']
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                    out = self.forward_lrm_generator(
         
     | 
| 177 | 
         
            +
                        images, 
         
     | 
| 178 | 
         
            +
                        cameras, 
         
     | 
| 179 | 
         
            +
                        render_cameras, 
         
     | 
| 180 | 
         
            +
                        render_size=render_size, 
         
     | 
| 181 | 
         
            +
                        crop_params=crop_params, 
         
     | 
| 182 | 
         
            +
                        chunk_size=1,
         
     | 
| 183 | 
         
            +
                    )
         
     | 
| 184 | 
         
            +
                    render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
         
     | 
| 185 | 
         
            +
                    render_depths = out['images_depth']
         
     | 
| 186 | 
         
            +
                    render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                    out = {
         
     | 
| 189 | 
         
            +
                        'render_images': render_images,
         
     | 
| 190 | 
         
            +
                        'render_depths': render_depths,
         
     | 
| 191 | 
         
            +
                        'render_alphas': render_alphas,
         
     | 
| 192 | 
         
            +
                    }
         
     | 
| 193 | 
         
            +
                    return out
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                def training_step(self, batch, batch_idx):
         
     | 
| 196 | 
         
            +
                    lrm_generator_input, render_gt = self.prepare_batch_data(batch)
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                    render_out = self.forward(lrm_generator_input)
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                    loss, loss_dict = self.compute_loss(render_out, render_gt)
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                    self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                    if self.global_step % 1000 == 0 and self.global_rank == 0:
         
     | 
| 205 | 
         
            +
                        B, N, C, H, W = render_gt['target_images'].shape
         
     | 
| 206 | 
         
            +
                        N_in = lrm_generator_input['images'].shape[1]
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                        input_images = v2.functional.resize(
         
     | 
| 209 | 
         
            +
                            lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
         
     | 
| 210 | 
         
            +
                        input_images = torch.cat(
         
     | 
| 211 | 
         
            +
                            [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
         
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
                        input_images = rearrange(
         
     | 
| 214 | 
         
            +
                            input_images, 'b n c h w -> b c h (n w)')
         
     | 
| 215 | 
         
            +
                        target_images = rearrange(
         
     | 
| 216 | 
         
            +
                            render_gt['target_images'], 'b n c h w -> b c h (n w)')
         
     | 
| 217 | 
         
            +
                        render_images = rearrange(
         
     | 
| 218 | 
         
            +
                            render_out['render_images'], 'b n c h w -> b c h (n w)')
         
     | 
| 219 | 
         
            +
                        target_alphas = rearrange(
         
     | 
| 220 | 
         
            +
                            repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 221 | 
         
            +
                        render_alphas = rearrange(
         
     | 
| 222 | 
         
            +
                            repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 223 | 
         
            +
                        target_depths = rearrange(
         
     | 
| 224 | 
         
            +
                            repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 225 | 
         
            +
                        render_depths = rearrange(
         
     | 
| 226 | 
         
            +
                            repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 227 | 
         
            +
                        MAX_DEPTH = torch.max(target_depths)
         
     | 
| 228 | 
         
            +
                        target_depths = target_depths / MAX_DEPTH * target_alphas
         
     | 
| 229 | 
         
            +
                        render_depths = render_depths / MAX_DEPTH
         
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
                        grid = torch.cat([
         
     | 
| 232 | 
         
            +
                            input_images, 
         
     | 
| 233 | 
         
            +
                            target_images, render_images, 
         
     | 
| 234 | 
         
            +
                            target_alphas, render_alphas, 
         
     | 
| 235 | 
         
            +
                            target_depths, render_depths,
         
     | 
| 236 | 
         
            +
                        ], dim=-2)
         
     | 
| 237 | 
         
            +
                        grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
                        save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
         
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
                    return loss
         
     | 
| 242 | 
         
            +
                
         
     | 
| 243 | 
         
            +
                def compute_loss(self, render_out, render_gt):
         
     | 
| 244 | 
         
            +
                    # NOTE: the rgb value range of OpenLRM is [0, 1]
         
     | 
| 245 | 
         
            +
                    render_images = render_out['render_images']
         
     | 
| 246 | 
         
            +
                    target_images = render_gt['target_images'].to(render_images)
         
     | 
| 247 | 
         
            +
                    render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
         
     | 
| 248 | 
         
            +
                    target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                    loss_mse = F.mse_loss(render_images, target_images)
         
     | 
| 251 | 
         
            +
                    loss_lpips = 2.0 * self.lpips(render_images, target_images)
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                    render_alphas = render_out['render_alphas']
         
     | 
| 254 | 
         
            +
                    target_alphas = render_gt['target_alphas']
         
     | 
| 255 | 
         
            +
                    loss_mask = F.mse_loss(render_alphas, target_alphas)
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                    loss = loss_mse + loss_lpips + loss_mask
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                    prefix = 'train'
         
     | 
| 260 | 
         
            +
                    loss_dict = {}
         
     | 
| 261 | 
         
            +
                    loss_dict.update({f'{prefix}/loss_mse': loss_mse})
         
     | 
| 262 | 
         
            +
                    loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
         
     | 
| 263 | 
         
            +
                    loss_dict.update({f'{prefix}/loss_mask': loss_mask})
         
     | 
| 264 | 
         
            +
                    loss_dict.update({f'{prefix}/loss': loss})
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    return loss, loss_dict
         
     | 
| 267 | 
         
            +
             
     | 
| 268 | 
         
            +
                @torch.no_grad()
         
     | 
| 269 | 
         
            +
                def validation_step(self, batch, batch_idx):
         
     | 
| 270 | 
         
            +
                    lrm_generator_input = self.prepare_validation_batch_data(batch)
         
     | 
| 271 | 
         
            +
             
     | 
| 272 | 
         
            +
                    render_out = self.forward(lrm_generator_input)
         
     | 
| 273 | 
         
            +
                    render_images = render_out['render_images']
         
     | 
| 274 | 
         
            +
                    render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
                    self.validation_step_outputs.append(render_images)
         
     | 
| 277 | 
         
            +
                
         
     | 
| 278 | 
         
            +
                def on_validation_epoch_end(self):
         
     | 
| 279 | 
         
            +
                    images = torch.cat(self.validation_step_outputs, dim=-1)
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                    all_images = self.all_gather(images)
         
     | 
| 282 | 
         
            +
                    all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                    if self.global_rank == 0:
         
     | 
| 285 | 
         
            +
                        image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
         
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
            +
                        grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
         
     | 
| 288 | 
         
            +
                        save_image(grid, image_path)
         
     | 
| 289 | 
         
            +
                        print(f"Saved image to {image_path}")
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                    self.validation_step_outputs.clear()
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                def configure_optimizers(self):
         
     | 
| 294 | 
         
            +
                    lr = self.learning_rate
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                    params = []
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                    lrm_params_fast, lrm_params_slow = [], []
         
     | 
| 299 | 
         
            +
                    for n, p in self.lrm_generator.named_parameters():
         
     | 
| 300 | 
         
            +
                        if 'adaLN_modulation' in n or 'camera_embedder' in n:
         
     | 
| 301 | 
         
            +
                            lrm_params_fast.append(p)
         
     | 
| 302 | 
         
            +
                        else:
         
     | 
| 303 | 
         
            +
                            lrm_params_slow.append(p)
         
     | 
| 304 | 
         
            +
                    params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
         
     | 
| 305 | 
         
            +
                    params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
                    optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
         
     | 
| 308 | 
         
            +
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
         
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
                    return {'optimizer': optimizer, 'lr_scheduler': scheduler}
         
     | 
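Note on `configure_optimizers` above: the adaLN modulation and camera-embedder parameters are trained at the full learning rate, everything else at one tenth of it, all under a single AdamW optimizer with a cosine warm-restart schedule. Below is a minimal, self-contained sketch of that two-speed parameter-group pattern; the toy module, the learning-rate value and the group layout are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn

# Toy stand-in for the LRM generator (hypothetical); only the parameter-group
# logic mirrors the configure_optimizers method above.
model = nn.ModuleDict({
    'adaLN_modulation': nn.Linear(8, 8),
    'camera_embedder': nn.Linear(8, 8),
    'backbone': nn.Linear(8, 8),
})

lr = 4e-4  # assumed value; the real learning rate comes from the training config
fast, slow = [], []
for name, p in model.named_parameters():
    if 'adaLN_modulation' in name or 'camera_embedder' in name:
        fast.append(p)   # modulation / camera-embedding weights: full LR
    else:
        slow.append(p)   # remaining weights: 10x smaller LR

optimizer = torch.optim.AdamW(
    [{'params': fast, 'lr': lr, 'weight_decay': 0.01},
     {'params': slow, 'lr': lr / 10.0, 'weight_decay': 0.01}],
    lr=lr, betas=(0.90, 0.95))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=3000, eta_min=lr / 4)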
    	
        src/model_mesh.py
    ADDED
    
         @@ -0,0 +1,325 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 5 | 
         
            +
            from torchvision.transforms import v2
         
     | 
| 6 | 
         
            +
            from torchvision.utils import make_grid, save_image
         
     | 
| 7 | 
         
            +
            from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
         
     | 
| 8 | 
         
            +
            import pytorch_lightning as pl
         
     | 
| 9 | 
         
            +
            from einops import rearrange, repeat
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from src.utils.train_util import instantiate_from_config
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            # Regulrarization loss for FlexiCubes
         
     | 
| 15 | 
         
            +
            def sdf_reg_loss_batch(sdf, all_edges):
         
     | 
| 16 | 
         
            +
                sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
         
     | 
| 17 | 
         
            +
                mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
         
     | 
| 18 | 
         
            +
                sdf_f1x6x2 = sdf_f1x6x2[mask]
         
     | 
| 19 | 
         
            +
                sdf_diff = F.binary_cross_entropy_with_logits(
         
     | 
| 20 | 
         
            +
                    sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
         
     | 
| 21 | 
         
            +
                           F.binary_cross_entropy_with_logits(
         
     | 
| 22 | 
         
            +
                               sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
         
     | 
| 23 | 
         
            +
                return sdf_diff
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            class MVRecon(pl.LightningModule):
         
     | 
| 27 | 
         
            +
                def __init__(
         
     | 
| 28 | 
         
            +
                    self,
         
     | 
| 29 | 
         
            +
                    lrm_generator_config,
         
     | 
| 30 | 
         
            +
                    input_size=256,
         
     | 
| 31 | 
         
            +
                    render_size=512,
         
     | 
| 32 | 
         
            +
                    init_ckpt=None,
         
     | 
| 33 | 
         
            +
                ):
         
     | 
| 34 | 
         
            +
                    super(MVRecon, self).__init__()
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    self.input_size = input_size
         
     | 
| 37 | 
         
            +
                    self.render_size = render_size
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                    # init modules
         
     | 
| 40 | 
         
            +
                    self.lrm_generator = instantiate_from_config(lrm_generator_config)
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                    self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                    # Load weights from pretrained MVRecon model, and use the mlp 
         
     | 
| 45 | 
         
            +
                    # weights to initialize the weights of sdf and rgb mlps.
         
     | 
| 46 | 
         
            +
                    if init_ckpt is not None:
         
     | 
| 47 | 
         
            +
                        sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
         
     | 
| 48 | 
         
            +
                        sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
         
     | 
| 49 | 
         
            +
                        sd_fc = {}
         
     | 
| 50 | 
         
            +
                        for k, v in sd.items():
         
     | 
| 51 | 
         
            +
                            if k.startswith('lrm_generator.synthesizer.decoder.net.'):
         
     | 
| 52 | 
         
            +
                                if k.startswith('lrm_generator.synthesizer.decoder.net.6.'):    # last layer
         
     | 
| 53 | 
         
            +
                                    # Here we assume the density filed's isosurface threshold is t, 
         
     | 
| 54 | 
         
            +
                                    # we reverse the sign of density filed to initialize SDF field.  
         
     | 
| 55 | 
         
            +
                                    # -(w*x + b - t) = (-w)*x + (t - b)
         
     | 
| 56 | 
         
            +
                                    if 'weight' in k:
         
     | 
| 57 | 
         
            +
                                        sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
         
     | 
| 58 | 
         
            +
                                    else:
         
     | 
| 59 | 
         
            +
                                        sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
         
     | 
| 60 | 
         
            +
                                    sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
         
     | 
| 61 | 
         
            +
                                else:
         
     | 
| 62 | 
         
            +
                                    sd_fc[k.replace('net.', 'net_sdf.')] = v
         
     | 
| 63 | 
         
            +
                                    sd_fc[k.replace('net.', 'net_rgb.')] = v
         
     | 
| 64 | 
         
            +
                            else:
         
     | 
| 65 | 
         
            +
                                sd_fc[k] = v
         
     | 
| 66 | 
         
            +
                        sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
         
     | 
| 67 | 
         
            +
                        # missing `net_deformation` and `net_weight` parameters
         
     | 
| 68 | 
         
            +
                        self.lrm_generator.load_state_dict(sd_fc, strict=False)
         
     | 
| 69 | 
         
            +
                        print(f'Loaded weights from {init_ckpt}')
         
     | 
| 70 | 
         
            +
                    
         
     | 
| 71 | 
         
            +
                    self.validation_step_outputs = []
         
     | 
| 72 | 
         
            +
                
         
     | 
| 73 | 
         
            +
                def on_fit_start(self):
         
     | 
| 74 | 
         
            +
                    device = torch.device(f'cuda:{self.global_rank}')
         
     | 
| 75 | 
         
            +
                    self.lrm_generator.init_flexicubes_geometry(device)
         
     | 
| 76 | 
         
            +
                    if self.global_rank == 0:
         
     | 
| 77 | 
         
            +
                        os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
         
     | 
| 78 | 
         
            +
                        os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
         
     | 
| 79 | 
         
            +
                
         
     | 
| 80 | 
         
            +
                def prepare_batch_data(self, batch):
         
     | 
| 81 | 
         
            +
                    lrm_generator_input = {}
         
     | 
| 82 | 
         
            +
                    render_gt = {}
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    # input images
         
     | 
| 85 | 
         
            +
                    images = batch['input_images']
         
     | 
| 86 | 
         
            +
                    images = v2.functional.resize(
         
     | 
| 87 | 
         
            +
                        images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    lrm_generator_input['images'] = images.to(self.device)
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                    # input cameras and render cameras
         
     | 
| 92 | 
         
            +
                    input_c2ws = batch['input_c2ws']
         
     | 
| 93 | 
         
            +
                    input_Ks = batch['input_Ks']
         
     | 
| 94 | 
         
            +
                    target_c2ws = batch['target_c2ws']
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
         
     | 
| 97 | 
         
            +
                    render_w2cs = torch.linalg.inv(render_c2ws)
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                    input_extrinsics = input_c2ws.flatten(-2)
         
     | 
| 100 | 
         
            +
                    input_extrinsics = input_extrinsics[:, :, :12]
         
     | 
| 101 | 
         
            +
                    input_intrinsics = input_Ks.flatten(-2)
         
     | 
| 102 | 
         
            +
                    input_intrinsics = torch.stack([
         
     | 
| 103 | 
         
            +
                        input_intrinsics[:, :, 0], input_intrinsics[:, :, 4], 
         
     | 
| 104 | 
         
            +
                        input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
         
     | 
| 105 | 
         
            +
                    ], dim=-1)
         
     | 
| 106 | 
         
            +
                    cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                    # add noise to input_cameras
         
     | 
| 109 | 
         
            +
                    cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    lrm_generator_input['cameras'] = cameras.to(self.device)
         
     | 
| 112 | 
         
            +
                    lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                    # target images
         
     | 
| 115 | 
         
            +
                    target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
         
     | 
| 116 | 
         
            +
                    target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
         
     | 
| 117 | 
         
            +
                    target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
         
     | 
| 118 | 
         
            +
                    target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    render_size = self.render_size
         
     | 
| 121 | 
         
            +
                    target_images = v2.functional.resize(
         
     | 
| 122 | 
         
            +
                        target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
         
     | 
| 123 | 
         
            +
                    target_depths = v2.functional.resize(
         
     | 
| 124 | 
         
            +
                        target_depths, render_size, interpolation=0, antialias=True)
         
     | 
| 125 | 
         
            +
                    target_alphas = v2.functional.resize(
         
     | 
| 126 | 
         
            +
                        target_alphas, render_size, interpolation=0, antialias=True)
         
     | 
| 127 | 
         
            +
                    target_normals = v2.functional.resize(
         
     | 
| 128 | 
         
            +
                        target_normals, render_size, interpolation=3, antialias=True)
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    lrm_generator_input['render_size'] = render_size
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    render_gt['target_images'] = target_images.to(self.device)
         
     | 
| 133 | 
         
            +
                    render_gt['target_depths'] = target_depths.to(self.device)
         
     | 
| 134 | 
         
            +
                    render_gt['target_alphas'] = target_alphas.to(self.device)
         
     | 
| 135 | 
         
            +
                    render_gt['target_normals'] = target_normals.to(self.device)
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    return lrm_generator_input, render_gt
         
     | 
| 138 | 
         
            +
                
         
     | 
| 139 | 
         
            +
                def prepare_validation_batch_data(self, batch):
         
     | 
| 140 | 
         
            +
                    lrm_generator_input = {}
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                    # input images
         
     | 
| 143 | 
         
            +
                    images = batch['input_images']
         
     | 
| 144 | 
         
            +
                    images = v2.functional.resize(
         
     | 
| 145 | 
         
            +
                        images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                    lrm_generator_input['images'] = images.to(self.device)
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                    # input cameras
         
     | 
| 150 | 
         
            +
                    input_c2ws = batch['input_c2ws'].flatten(-2)
         
     | 
| 151 | 
         
            +
                    input_Ks = batch['input_Ks'].flatten(-2)
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                    input_extrinsics = input_c2ws[:, :, :12]
         
     | 
| 154 | 
         
            +
                    input_intrinsics = torch.stack([
         
     | 
| 155 | 
         
            +
                        input_Ks[:, :, 0], input_Ks[:, :, 4], 
         
     | 
| 156 | 
         
            +
                        input_Ks[:, :, 2], input_Ks[:, :, 5],
         
     | 
| 157 | 
         
            +
                    ], dim=-1)
         
     | 
| 158 | 
         
            +
                    cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                    lrm_generator_input['cameras'] = cameras.to(self.device)
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                    # render cameras
         
     | 
| 163 | 
         
            +
                    render_c2ws = batch['render_c2ws']
         
     | 
| 164 | 
         
            +
                    render_w2cs = torch.linalg.inv(render_c2ws)
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                    lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
         
     | 
| 167 | 
         
            +
                    lrm_generator_input['render_size'] = 384
         
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                    return lrm_generator_input
         
     | 
| 170 | 
         
            +
                
         
     | 
| 171 | 
         
            +
                def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
         
     | 
| 172 | 
         
            +
                    planes = torch.utils.checkpoint.checkpoint(
         
     | 
| 173 | 
         
            +
                        self.lrm_generator.forward_planes, 
         
     | 
| 174 | 
         
            +
                        images, 
         
     | 
| 175 | 
         
            +
                        cameras, 
         
     | 
| 176 | 
         
            +
                        use_reentrant=False,
         
     | 
| 177 | 
         
            +
                    )
         
     | 
| 178 | 
         
            +
                    out = self.lrm_generator.forward_geometry(
         
     | 
| 179 | 
         
            +
                        planes, 
         
     | 
| 180 | 
         
            +
                        render_cameras, 
         
     | 
| 181 | 
         
            +
                        render_size,
         
     | 
| 182 | 
         
            +
                    )
         
     | 
| 183 | 
         
            +
                    return out
         
     | 
| 184 | 
         
            +
                
         
     | 
| 185 | 
         
            +
                def forward(self, lrm_generator_input):
         
     | 
| 186 | 
         
            +
                    images = lrm_generator_input['images']
         
     | 
| 187 | 
         
            +
                    cameras = lrm_generator_input['cameras']
         
     | 
| 188 | 
         
            +
                    render_cameras = lrm_generator_input['render_cameras']
         
     | 
| 189 | 
         
            +
                    render_size = lrm_generator_input['render_size']
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                    out = self.forward_lrm_generator(
         
     | 
| 192 | 
         
            +
                        images, cameras, render_cameras, render_size=render_size)
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                    return out
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                def training_step(self, batch, batch_idx):
         
     | 
| 197 | 
         
            +
                    lrm_generator_input, render_gt = self.prepare_batch_data(batch)
         
     | 
| 198 | 
         
            +
             
     | 
| 199 | 
         
            +
                    render_out = self.forward(lrm_generator_input)
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
                    loss, loss_dict = self.compute_loss(render_out, render_gt)
         
     | 
| 202 | 
         
            +
             
     | 
| 203 | 
         
            +
                    self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                    if self.global_step % 1000 == 0 and self.global_rank == 0:
         
     | 
| 206 | 
         
            +
                        B, N, C, H, W = render_gt['target_images'].shape
         
     | 
| 207 | 
         
            +
                        N_in = lrm_generator_input['images'].shape[1]
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                        target_images = rearrange(
         
     | 
| 210 | 
         
            +
                            render_gt['target_images'], 'b n c h w -> b c h (n w)')
         
     | 
| 211 | 
         
            +
                        render_images = rearrange(
         
     | 
| 212 | 
         
            +
                            render_out['img'], 'b n c h w -> b c h (n w)')
         
     | 
| 213 | 
         
            +
                        target_alphas = rearrange(
         
     | 
| 214 | 
         
            +
                            repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 215 | 
         
            +
                        render_alphas = rearrange(
         
     | 
| 216 | 
         
            +
                            repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 217 | 
         
            +
                        target_depths = rearrange(
         
     | 
| 218 | 
         
            +
                            repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 219 | 
         
            +
                        render_depths = rearrange(
         
     | 
| 220 | 
         
            +
                            repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
         
     | 
| 221 | 
         
            +
                        target_normals = rearrange(
         
     | 
| 222 | 
         
            +
                            render_gt['target_normals'], 'b n c h w -> b c h (n w)')
         
     | 
| 223 | 
         
            +
                        render_normals = rearrange(
         
     | 
| 224 | 
         
            +
                            render_out['normal'], 'b n c h w -> b c h (n w)')
         
     | 
| 225 | 
         
            +
                        MAX_DEPTH = torch.max(target_depths)
         
     | 
| 226 | 
         
            +
                        target_depths = target_depths / MAX_DEPTH * target_alphas
         
     | 
| 227 | 
         
            +
                        render_depths = render_depths / MAX_DEPTH
         
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
                        grid = torch.cat([
         
     | 
| 230 | 
         
            +
                            target_images, render_images, 
         
     | 
| 231 | 
         
            +
                            target_alphas, render_alphas, 
         
     | 
| 232 | 
         
            +
                            target_depths, render_depths, 
         
     | 
| 233 | 
         
            +
                            target_normals, render_normals,
         
     | 
| 234 | 
         
            +
                        ], dim=-2)
         
     | 
| 235 | 
         
            +
                        grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                        image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
         
     | 
| 238 | 
         
            +
                        save_image(grid, image_path)
         
     | 
| 239 | 
         
            +
                        print(f"Saved image to {image_path}")
         
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
                    return loss
         
     | 
| 242 | 
         
            +
                
         
     | 
| 243 | 
         
            +
                def compute_loss(self, render_out, render_gt):
         
     | 
| 244 | 
         
            +
                    # NOTE: the rgb value range of OpenLRM is [0, 1]
         
     | 
| 245 | 
         
            +
                    render_images = render_out['img']
         
     | 
| 246 | 
         
            +
                    target_images = render_gt['target_images'].to(render_images)
         
     | 
| 247 | 
         
            +
                    render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
         
     | 
| 248 | 
         
            +
                    target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
         
     | 
| 249 | 
         
            +
                    loss_mse = F.mse_loss(render_images, target_images)
         
     | 
| 250 | 
         
            +
                    loss_lpips = 2.0 * self.lpips(render_images, target_images)
         
     | 
| 251 | 
         
            +
             
     | 
| 252 | 
         
            +
                    render_alphas = render_out['mask']
         
     | 
| 253 | 
         
            +
                    target_alphas = render_gt['target_alphas']
         
     | 
| 254 | 
         
            +
                    loss_mask = F.mse_loss(render_alphas, target_alphas)
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                    render_depths = render_out['depth']
         
     | 
| 257 | 
         
            +
                    target_depths = render_gt['target_depths']
         
     | 
| 258 | 
         
            +
                    loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
                    render_normals = render_out['normal'] * 2.0 - 1.0
         
     | 
| 261 | 
         
            +
                    target_normals = render_gt['target_normals'] * 2.0 - 1.0
         
     | 
| 262 | 
         
            +
                    similarity = (render_normals * target_normals).sum(dim=-3).abs()
         
     | 
| 263 | 
         
            +
                    normal_mask = target_alphas.squeeze(-3)
         
     | 
| 264 | 
         
            +
                    loss_normal = 1 - similarity[normal_mask>0].mean()
         
     | 
| 265 | 
         
            +
                    loss_normal = 0.2 * loss_normal
         
     | 
| 266 | 
         
            +
             
     | 
| 267 | 
         
            +
                    # flexicubes regularization loss
         
     | 
| 268 | 
         
            +
                    sdf = render_out['sdf']
         
     | 
| 269 | 
         
            +
                    sdf_reg_loss = render_out['sdf_reg_loss']
         
     | 
| 270 | 
         
            +
                    sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
         
     | 
| 271 | 
         
            +
                    _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
         
     | 
| 272 | 
         
            +
                    flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
         
     | 
| 273 | 
         
            +
                    flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
         
     | 
| 274 | 
         
            +
             
     | 
| 275 | 
         
            +
                    loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
         
     | 
| 276 | 
         
            +
             
     | 
| 277 | 
         
            +
                    loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
         
     | 
| 278 | 
         
            +
             
     | 
| 279 | 
         
            +
                    prefix = 'train'
         
     | 
| 280 | 
         
            +
                    loss_dict = {}
         
     | 
| 281 | 
         
            +
                    loss_dict.update({f'{prefix}/loss_mse': loss_mse})
         
     | 
| 282 | 
         
            +
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
        loss_dict.update({f'{prefix}/loss_normal': loss_normal})
        loss_dict.update({f'{prefix}/loss_depth': loss_depth})
        loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
        loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
        loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        lrm_generator_input = self.prepare_validation_batch_data(batch)

        render_out = self.forward(lrm_generator_input)
        render_images = render_out['img']
        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')

        self.validation_step_outputs.append(render_images)

    def on_validation_epoch_end(self):
        images = torch.cat(self.validation_step_outputs, dim=-1)

        all_images = self.all_gather(images)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        lr = self.learning_rate

        optimizer = torch.optim.AdamW(
            self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
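A minimal, self-contained sketch of the optimizer recipe used in configure_optimizers above, with the same AdamW betas, weight decay, and CosineAnnealingWarmRestarts restart period. The stand-in module, learning rate, and training loop are illustrative assumptions rather than values from this repo; in the Space itself, PyTorch Lightning drives the optimizer and scheduler stepping.

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)   # hypothetical stand-in for self.lrm_generator
    lr = 4e-5                   # assumed value; the real rate comes from self.learning_rate

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100000, eta_min=0)

    for _ in range(3):          # toy loop; Lightning normally calls step() for you
        optimizer.zero_grad()
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()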
    	
src/models/__init__.py ADDED
    File without changes

src/models/decoder/__init__.py ADDED
    File without changes

src/models/decoder/transformer.py ADDED
    @@ -0,0 +1,123 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class BasicTransformerBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    """
    # use attention from torch.nn.MultiHeadAttention
    # Block contains a cross-attention layer, a self-attention layer, and a MLP
    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.,
        attn_bias: bool = False,
        mlp_ratio: float = 4.,
        mlp_drop: float = 0.,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(inner_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
        x = x + self.mlp(self.norm3(x))
        return x


class TriplaneTransformer(nn.Module):
    """
    Transformer with condition that generates a triplane representation.

    Reference:
    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(
        self,
        inner_dim: int,
        image_feat_dim: int,
        triplane_low_res: int,
        triplane_high_res: int,
        triplane_dim: int,
        num_layers: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
        self.layers = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats):
        # image_feats: [N, L_cond, D_cond]

        N = image_feats.shape[0]
        H = W = self.triplane_low_res
        L = 3 * H * W

        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        for layer in self.layers:
            x = layer(x, image_feats)
        x = self.norm(x)

        # separate each plane and apply deconv
        x = x.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.deconv(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()

        return x
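A quick shape-check sketch for the TriplaneTransformer defined above. The dimensions are illustrative assumptions, not the values from the configs in this Space; it only demonstrates that the transposed convolution doubles the low-resolution triplane grid.

    import torch

    model = TriplaneTransformer(
        inner_dim=64, image_feat_dim=32,
        triplane_low_res=8, triplane_high_res=16, triplane_dim=20,
        num_layers=2, num_heads=4,
    )
    image_feats = torch.randn(2, 100, 32)   # [N, L_cond, D_cond]
    planes = model(image_feats)
    print(planes.shape)                     # torch.Size([2, 3, 20, 16, 16]): the 8x8 grid is upsampled to 16x16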
    	
src/models/encoder/__init__.py ADDED
    File without changes
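The dino.py encoder added below builds its token sequence from an image_size/patch_size grid and, via interpolate_pos_encoding, rescales the pretrained position grid when a larger input is used. A small arithmetic sketch of that relationship, with illustrative sizes only:

    import math

    image_size, patch_size = 224, 16                  # assumed ViTConfig values, for illustration
    num_patches = (image_size // patch_size) ** 2     # 14 * 14 = 196 patch tokens (+1 CLS token)

    # A 320x320 input gives a 20x20 patch grid, so the stored 14x14 position grid is
    # bicubically resized by roughly this factor (the 0.1 offset mirrors interpolate_pos_encoding):
    new_h = 320 // patch_size                         # 20
    scale = (new_h + 0.1) / math.sqrt(num_patches)    # ~1.436
    print(num_patches, new_h, round(scale, 3))        # 196 20 1.436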
    	
src/models/encoder/dino.py ADDED
    @@ -0,0 +1,550 @@
| 1 | 
         
            +
            # coding=utf-8
         
     | 
| 2 | 
         
            +
            # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
         
     | 
| 3 | 
         
            +
            #
         
     | 
| 4 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 5 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 6 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 7 | 
         
            +
            #
         
     | 
| 8 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 9 | 
         
            +
            #
         
     | 
| 10 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 11 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 12 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 13 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 14 | 
         
            +
            # limitations under the License.
         
     | 
| 15 | 
         
            +
            """ PyTorch ViT model."""
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import collections.abc
         
     | 
| 19 | 
         
            +
            import math
         
     | 
| 20 | 
         
            +
            from typing import Dict, List, Optional, Set, Tuple, Union
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            import torch
         
     | 
| 23 | 
         
            +
            from torch import nn
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            from transformers.activations import ACT2FN
         
     | 
| 26 | 
         
            +
            from transformers.modeling_outputs import (
         
     | 
| 27 | 
         
            +
                BaseModelOutput,
         
     | 
| 28 | 
         
            +
                BaseModelOutputWithPooling,
         
     | 
| 29 | 
         
            +
            )
         
     | 
| 30 | 
         
            +
            from transformers import PreTrainedModel, ViTConfig
         
     | 
| 31 | 
         
            +
            from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            class ViTEmbeddings(nn.Module):
         
     | 
| 35 | 
         
            +
                """
         
     | 
| 36 | 
         
            +
                Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
         
     | 
| 37 | 
         
            +
                """
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         
     | 
| 40 | 
         
            +
                    super().__init__()
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                    self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         
     | 
| 43 | 
         
            +
                    self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         
     | 
| 44 | 
         
            +
                    self.patch_embeddings = ViTPatchEmbeddings(config)
         
     | 
| 45 | 
         
            +
                    num_patches = self.patch_embeddings.num_patches
         
     | 
| 46 | 
         
            +
                    self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         
     | 
| 47 | 
         
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         
     | 
| 48 | 
         
            +
                    self.config = config
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         
     | 
| 51 | 
         
            +
                    """
         
     | 
| 52 | 
         
            +
                    This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
         
     | 
| 53 | 
         
            +
                    resolution images.
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                    Source:
         
     | 
| 56 | 
         
            +
                    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         
     | 
| 57 | 
         
            +
                    """
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                    num_patches = embeddings.shape[1] - 1
         
     | 
| 60 | 
         
            +
                    num_positions = self.position_embeddings.shape[1] - 1
         
     | 
| 61 | 
         
            +
                    if num_patches == num_positions and height == width:
         
     | 
| 62 | 
         
            +
                        return self.position_embeddings
         
     | 
| 63 | 
         
            +
                    class_pos_embed = self.position_embeddings[:, 0]
         
     | 
| 64 | 
         
            +
                    patch_pos_embed = self.position_embeddings[:, 1:]
         
     | 
| 65 | 
         
            +
                    dim = embeddings.shape[-1]
         
     | 
| 66 | 
         
            +
                    h0 = height // self.config.patch_size
         
     | 
| 67 | 
         
            +
                    w0 = width // self.config.patch_size
         
     | 
| 68 | 
         
            +
                    # we add a small number to avoid floating point error in the interpolation
         
     | 
| 69 | 
         
            +
                    # see discussion at https://github.com/facebookresearch/dino/issues/8
         
     | 
| 70 | 
         
            +
                    h0, w0 = h0 + 0.1, w0 + 0.1
         
     | 
| 71 | 
         
            +
                    patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
         
     | 
| 72 | 
         
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
         
     | 
| 73 | 
         
            +
                    patch_pos_embed = nn.functional.interpolate(
         
     | 
| 74 | 
         
            +
                        patch_pos_embed,
         
     | 
| 75 | 
         
            +
                        scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
         
     | 
| 76 | 
         
            +
                        mode="bicubic",
         
     | 
| 77 | 
         
            +
                        align_corners=False,
         
     | 
| 78 | 
         
            +
                    )
         
     | 
| 79 | 
         
            +
                    assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
         
     | 
| 80 | 
         
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         
     | 
| 81 | 
         
            +
                    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                def forward(
         
     | 
| 84 | 
         
            +
                    self,
         
     | 
| 85 | 
         
            +
                    pixel_values: torch.Tensor,
         
     | 
| 86 | 
         
            +
                    bool_masked_pos: Optional[torch.BoolTensor] = None,
         
     | 
| 87 | 
         
            +
                    interpolate_pos_encoding: bool = False,
         
     | 
| 88 | 
         
            +
                ) -> torch.Tensor:
         
     | 
| 89 | 
         
            +
                    batch_size, num_channels, height, width = pixel_values.shape
         
     | 
| 90 | 
         
            +
                    embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    if bool_masked_pos is not None:
         
     | 
| 93 | 
         
            +
                        seq_length = embeddings.shape[1]
         
     | 
| 94 | 
         
            +
                        mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
         
     | 
| 95 | 
         
            +
                        # replace the masked visual tokens by mask_tokens
         
     | 
| 96 | 
         
            +
                        mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
         
     | 
| 97 | 
         
            +
                        embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                    # add the [CLS] token to the embedded patch tokens
         
     | 
| 100 | 
         
            +
                    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         
     | 
| 101 | 
         
            +
                    embeddings = torch.cat((cls_tokens, embeddings), dim=1)
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                    # add positional encoding to each token
         
     | 
| 104 | 
         
            +
                    if interpolate_pos_encoding:
         
     | 
| 105 | 
         
            +
                        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
         
     | 
| 106 | 
         
            +
                    else:
         
     | 
| 107 | 
         
            +
                        embeddings = embeddings + self.position_embeddings
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
                    embeddings = self.dropout(embeddings)
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    return embeddings
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
            class ViTPatchEmbeddings(nn.Module):
         
     | 
| 115 | 
         
            +
                """
         
     | 
| 116 | 
         
            +
                This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
         
     | 
| 117 | 
         
            +
                `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
         
     | 
| 118 | 
         
            +
                Transformer.
         
     | 
| 119 | 
         
            +
                """
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                def __init__(self, config):
         
     | 
| 122 | 
         
            +
                    super().__init__()
         
     | 
| 123 | 
         
            +
                    image_size, patch_size = config.image_size, config.patch_size
         
     | 
| 124 | 
         
            +
                    num_channels, hidden_size = config.num_channels, config.hidden_size
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                    image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
         
     | 
| 127 | 
         
            +
                    patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         
     | 
| 128 | 
         
            +
                    num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         
     | 
| 129 | 
         
            +
                    self.image_size = image_size
         
     | 
| 130 | 
         
            +
                    self.patch_size = patch_size
         
     | 
| 131 | 
         
            +
                    self.num_channels = num_channels
         
     | 
| 132 | 
         
            +
                    self.num_patches = num_patches
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
                    self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
         
     | 
| 137 | 
         
            +
                    batch_size, num_channels, height, width = pixel_values.shape
         
     | 
| 138 | 
         
            +
                    if num_channels != self.num_channels:
         
     | 
| 139 | 
         
            +
                        raise ValueError(
         
     | 
| 140 | 
         
            +
                            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
         
     | 
| 141 | 
         
            +
                            f" Expected {self.num_channels} but got {num_channels}."
         
     | 
| 142 | 
         
            +
                        )
         
     | 
| 143 | 
         
            +
                    if not interpolate_pos_encoding:
         
     | 
| 144 | 
         
            +
                        if height != self.image_size[0] or width != self.image_size[1]:
         
     | 
| 145 | 
         
            +
                            raise ValueError(
         
     | 
| 146 | 
         
            +
                                f"Input image size ({height}*{width}) doesn't match model"
         
     | 
| 147 | 
         
            +
                                f" ({self.image_size[0]}*{self.image_size[1]})."
         
     | 
| 148 | 
         
            +
                            )
         
     | 
| 149 | 
         
            +
                    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
         
     | 
| 150 | 
         
            +
                    return embeddings
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
            class ViTSelfAttention(nn.Module):
         
     | 
| 154 | 
         
            +
                def __init__(self, config: ViTConfig) -> None:
         
     | 
| 155 | 
         
            +
                    super().__init__()
         
     | 
| 156 | 
         
            +
                    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
         
     | 
| 157 | 
         
            +
                        raise ValueError(
         
     | 
| 158 | 
         
            +
                            f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
         
     | 
| 159 | 
         
            +
                            f"heads {config.num_attention_heads}."
         
     | 
| 160 | 
         
            +
                        )
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                    self.num_attention_heads = config.num_attention_heads
         
     | 
| 163 | 
         
            +
                    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         
     | 
| 164 | 
         
            +
                    self.all_head_size = self.num_attention_heads * self.attention_head_size
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                    self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         
     | 
| 167 | 
         
            +
                    self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         
     | 
| 168 | 
         
            +
                    self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
                def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         
     | 
| 173 | 
         
            +
                    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         
     | 
| 174 | 
         
            +
                    x = x.view(new_x_shape)
         
     | 
| 175 | 
         
            +
                    return x.permute(0, 2, 1, 3)
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                def forward(
         
     | 
| 178 | 
         
            +
                    self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
         
     | 
| 179 | 
         
            +
                ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         
     | 
| 180 | 
         
            +
                    mixed_query_layer = self.query(hidden_states)
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    key_layer = self.transpose_for_scores(self.key(hidden_states))
         
     | 
| 183 | 
         
            +
                    value_layer = self.transpose_for_scores(self.value(hidden_states))
         
     | 
| 184 | 
         
            +
                    query_layer = self.transpose_for_scores(mixed_query_layer)
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                    # Take the dot product between "query" and "key" to get the raw attention scores.
         
     | 
| 187 | 
         
            +
                    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                    # Normalize the attention scores to probabilities.
         
     | 
| 192 | 
         
            +
                    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                    # This is actually dropping out entire tokens to attend to, which might
         
     | 
| 195 | 
         
            +
                    # seem a bit unusual, but is taken from the original Transformer paper.
         
     | 
| 196 | 
         
            +
                    attention_probs = self.dropout(attention_probs)
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                    # Mask heads if we want to
         
     | 
| 199 | 
         
            +
                    if head_mask is not None:
         
     | 
| 200 | 
         
            +
                        attention_probs = attention_probs * head_mask
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                    context_layer = torch.matmul(attention_probs, value_layer)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         
     | 
| 205 | 
         
            +
                    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         
     | 
| 206 | 
         
            +
                    context_layer = context_layer.view(new_context_layer_shape)
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                    return outputs
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
            class ViTSelfOutput(nn.Module):
         
     | 
| 214 | 
         
            +
                """
         
     | 
| 215 | 
         
            +
                The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
         
     | 
| 216 | 
         
            +
                layernorm applied before each block.
         
     | 
| 217 | 
         
            +
                """
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                def __init__(self, config: ViTConfig) -> None:
         
     | 
| 220 | 
         
            +
                    super().__init__()
         
     | 
| 221 | 
         
            +
                    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         
     | 
| 222 | 
         
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
                def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         
     | 
| 225 | 
         
            +
                    hidden_states = self.dense(hidden_states)
         
     | 
| 226 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                    return hidden_states
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
            class ViTAttention(nn.Module):
         
     | 
| 232 | 
         
            +
                def __init__(self, config: ViTConfig) -> None:
         
     | 
| 233 | 
         
            +
                    super().__init__()
         
     | 
| 234 | 
         
            +
                    self.attention = ViTSelfAttention(config)
         
     | 
| 235 | 
         
            +
                    self.output = ViTSelfOutput(config)
         
     | 
| 236 | 
         
            +
                    self.pruned_heads = set()
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                def prune_heads(self, heads: Set[int]) -> None:
         
     | 
| 239 | 
         
            +
                    if len(heads) == 0:
         
     | 
| 240 | 
         
            +
                        return
         
     | 
| 241 | 
         
            +
                    heads, index = find_pruneable_heads_and_indices(
         
     | 
| 242 | 
         
            +
                        heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
         
     | 
| 243 | 
         
            +
                    )
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                    # Prune linear layers
         
     | 
| 246 | 
         
            +
                    self.attention.query = prune_linear_layer(self.attention.query, index)
         
     | 
| 247 | 
         
            +
                    self.attention.key = prune_linear_layer(self.attention.key, index)
         
     | 
| 248 | 
         
            +
                    self.attention.value = prune_linear_layer(self.attention.value, index)
         
     | 
| 249 | 
         
            +
                    self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                    # Update hyper params and store pruned heads
         
     | 
| 252 | 
         
            +
                    self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
         
     | 
| 253 | 
         
            +
                    self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
         
     | 
| 254 | 
         
            +
                    self.pruned_heads = self.pruned_heads.union(heads)
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                def forward(
         
     | 
| 257 | 
         
            +
                    self,
         
     | 
| 258 | 
         
            +
                    hidden_states: torch.Tensor,
         
     | 
| 259 | 
         
            +
                    head_mask: Optional[torch.Tensor] = None,
         
     | 
| 260 | 
         
            +
                    output_attentions: bool = False,
         
     | 
| 261 | 
         
            +
                ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         
     | 
| 262 | 
         
            +
                    self_outputs = self.attention(hidden_states, head_mask, output_attentions)
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                    attention_output = self.output(self_outputs[0], hidden_states)
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         
     | 
| 267 | 
         
            +
                    return outputs
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
            class ViTIntermediate(nn.Module):
         
     | 
| 271 | 
         
            +
                def __init__(self, config: ViTConfig) -> None:
         
     | 
| 272 | 
         
            +
                    super().__init__()
         
     | 
| 273 | 
         
            +
                    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         
     | 
| 274 | 
         
            +
                    if isinstance(config.hidden_act, str):
         
     | 
| 275 | 
         
            +
                        self.intermediate_act_fn = ACT2FN[config.hidden_act]
         
     | 
| 276 | 
         
            +
                    else:
         
     | 
| 277 | 
         
            +
                        self.intermediate_act_fn = config.hidden_act
         
     | 
| 278 | 
         
            +
             
     | 
| 279 | 
         
            +
                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         
     | 
| 280 | 
         
            +
                    hidden_states = self.dense(hidden_states)
         
     | 
| 281 | 
         
            +
                    hidden_states = self.intermediate_act_fn(hidden_states)
         
     | 
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
                    return hidden_states
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
             
     | 
| 286 | 
         
            +
            class ViTOutput(nn.Module):
         
     | 
| 287 | 
         
            +
                def __init__(self, config: ViTConfig) -> None:
         
     | 
| 288 | 
         
            +
                    super().__init__()
         
     | 
| 289 | 
         
            +
                    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         
     | 
| 290 | 
         
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         
     | 
| 291 | 
         
            +
             
     | 
| 292 | 
         
            +
                def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         
     | 
| 293 | 
         
            +
                    hidden_states = self.dense(hidden_states)
         
     | 
| 294 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                    hidden_states = hidden_states + input_tensor
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                    return hidden_states
         
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
            def modulate(x, shift, scale):
         
     | 
| 302 | 
         
            +
                return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
         
class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
        )
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

        self_attention_outputs = self.attention(
            modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
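# Note (illustrative comment, not part of the original file): the last linear
# layer of `adaLN_modulation` is zero-initialized above, so at construction
# shift = scale = 0 and `modulate` is the identity; each ViTLayer then behaves
# exactly like a standard pre-LN ViT block until the conditioning weights change.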
         
class ViTEncoder(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    adaln_input,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)


class ViTModel(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        adaln_input: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            adaln_input=adaln_input,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ViTPooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
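The only functional change relative to the stock transformers ViT is the extra `adaln_input` argument threaded through `ViTModel` → `ViTEncoder` → `ViTLayer`. The following is a minimal shape sanity check, not part of the commit; it assumes the file lives at `src/models/encoder/dino.py` as in the listing above, and the small config values are arbitrary illustration:

```python
import torch
from transformers import ViTConfig

from src.models.encoder.dino import ViTModel  # import path assumed from the file listing

# Tiny config so the check runs quickly; values are illustrative only.
config = ViTConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
                   intermediate_size=128, image_size=32, patch_size=8)
model = ViTModel(config, add_pooling_layer=False)

pixel_values = torch.randn(2, 3, 32, 32)            # [B, C, H, W]
adaln_input = torch.randn(2, config.hidden_size)    # one conditioning vector per sample
out = model(pixel_values, adaln_input=adaln_input)
# 1 CLS token + (32 // 8) ** 2 patches = 17 tokens
print(out.last_hidden_state.shape)                  # torch.Size([2, 17, 64])
```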
    	
src/models/encoder/dino_wrapper.py ADDED
@@ -0,0 +1,80 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn
from transformers import ViTImageProcessor
from einops import rearrange, repeat
from .dino import ViTModel


class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.
    """
    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        self.camera_embedder = nn.Sequential(
            nn.Linear(16, self.model.config.hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
        )
        if freeze:
            self._freeze()

    def forward(self, image, camera):
        # image: [B, N, C, H, W]
        # camera: [B, N, D]
        # RGB image with [0,1] scale and properly sized
        if image.ndim == 5:
            image = rearrange(image, 'b n c h w -> (b n) c h w')
        dtype = image.dtype
        inputs = self.processor(
            images=image.float(),
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        ).to(self.model.device).to(dtype)
        # embed camera
        N = camera.shape[1]
        camera_embeddings = self.camera_embedder(camera)
        camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
        embeddings = camera_embeddings
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states

    def _freeze(self):
        print(f"======== Freezing DinoWrapper ========")
        self.model.eval()
        for name, param in self.model.named_parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        import requests
        try:
            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
            processor = ViTImageProcessor.from_pretrained(model_name)
            return model, processor
        except requests.exceptions.ProxyError as err:
            if proxy_error_retries > 0:
                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
                import time
                time.sleep(proxy_error_cooldown)
                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
            else:
                raise err
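For reference, a sketch of how `DinoWrapper` can be driven; it is not part of the commit. The checkpoint name `facebook/dino-vitb16` and the 224x224 resolution are assumptions for illustration (the Space's configs define the actual encoder name and image size), and the 16-dim camera vector matches the `nn.Linear(16, ...)` embedder above:

```python
import torch
from src.models.encoder.dino_wrapper import DinoWrapper

# Checkpoint name is an assumption; any ViT/DINO checkpoint loadable by
# ViTModel.from_pretrained would do. The adaLN and camera-embedder parameters
# are new and are not present in the checkpoint.
wrapper = DinoWrapper(model_name='facebook/dino-vitb16', freeze=True)

B, N = 1, 6                               # one object, six input views
image = torch.rand(B, N, 3, 224, 224)     # RGB in [0, 1], already sized (do_resize=False)
camera = torch.rand(B, N, 16)             # flattened per-view camera parameters
with torch.no_grad():
    tokens = wrapper(image, camera)       # [(B * N), 1 + num_patches, hidden_size]
print(tokens.shape)                       # torch.Size([6, 197, 768]) for ViT-B/16 at 224x224
```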