Commit bd86ed9
Parent(s): db5b5dc

here we go

Browse files
- .DS_Store +0 -0
- app.py +109 -0
- checkpoints/kittieigen_L.pth +3 -0
- checkpoints/nyu_L.pth +3 -0
- iebins/dataloaders/__init__.py +0 -0
- iebins/dataloaders/__pycache__/__init__.cpython-38.pyc +0 -0
- iebins/dataloaders/__pycache__/dataloader.cpython-38.pyc +0 -0
- iebins/dataloaders/__pycache__/dataloader_sun.cpython-38.pyc +0 -0
- iebins/dataloaders/dataloader.py +343 -0
- iebins/dataloaders/dataloader_sun.py +326 -0
- iebins/eval.py +177 -0
- iebins/eval_sun.py +179 -0
- iebins/inference_single_image.py +117 -0
- iebins/networks/NewCRFDepth.py +318 -0
- iebins/networks/__init__.py +0 -0
- iebins/networks/depth_update.py +39 -0
- iebins/networks/newcrf_layers.py +433 -0
- iebins/networks/newcrf_utils.py +264 -0
- iebins/networks/resize.py +51 -0
- iebins/networks/swin_transformer.py +620 -0
- iebins/networks/uper_crf_head.py +364 -0
- iebins/sum_depth.py +22 -0
- iebins/test.py +209 -0
- iebins/train.py +499 -0
- iebins/utils.py +356 -0
- iebins/utils/transfrom.py +250 -0
- requirements.txt +12 -0
    	
.DS_Store ADDED
Binary file (6.15 kB)
    	
app.py ADDED
@@ -0,0 +1,109 @@
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider

from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.utils.transfrom import Resize, NormalizeImage, PrepareForNet

css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = NewCRFDepth(version="large07", inv_depth=False,
                    max_depth=10, pretrained=None).to(DEVICE).eval()
model.load_state_dict(torch.load('checkpoints/nyu_L.pth'))

title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""

transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])


@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    return model(image)


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")
    gr.Markdown(
        "You can slide the output to compare the depth prediction with input image")

    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
    raw_file = gr.File(
        label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit")

    def on_submit(image):
        original_image = image.copy()

        h, w = image.shape[:2]

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

        depth = predict_depth(model, image)
        depth = F.interpolate(depth[None], (h, w),
                              mode='bilinear', align_corners=False)[0, 0]

        raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth.save(tmp.name)

        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth = depth.cpu().numpy().astype(np.uint8)
        colored_depth = cv2.applyColorMap(
            depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]

        return [(original_image, colored_depth), tmp.name]

    submit.click(on_submit, inputs=[input_image], outputs=[
                 depth_image_slider, raw_file])

    example_files = os.listdir('examples')
    example_files.sort()
    example_files = [os.path.join('examples', filename)
                     for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
                           depth_image_slider, raw_file], fn=on_submit, cache_examples=False)


if __name__ == '__main__':
    demo.queue().launch()
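
For debugging the Space outside Gradio, the preprocessing half of on_submit can be exercised on its own. The sketch below is not part of the commit: it assumes the repository root is on PYTHONPATH and feeds a synthetic image through the same Resize / NormalizeImage / PrepareForNet pipeline that app.py builds, producing the tensor that predict_depth would receive.

# Illustrative sketch only (not in the commit): mirrors the preprocessing steps of
# on_submit on a dummy image, assuming the repo root is on PYTHONPATH.
import cv2
import numpy as np
import torch
from torchvision.transforms import Compose

from iebins.utils.transfrom import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

# Stand-in for cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0 on an uploaded image.
rgb = np.random.rand(480, 640, 3).astype(np.float32)

x = torch.from_numpy(transform({'image': rgb})['image']).unsqueeze(0)
print(x.shape)  # expected (1, 3, H, W) with H and W rounded to multiples of 14

From there, on_submit moves the tensor to DEVICE, calls predict_depth, interpolates the result back to the input resolution, and writes both the colored map and the 16-bit PNG.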
    	
checkpoints/kittieigen_L.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf10549a615b19b96ffdddc82e639662c421fe0cd30008cc3cf3e7d4bffa5f55
size 3276188594
    	
checkpoints/nyu_L.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81d95d5f26f5d01b7e8b060467eef77ea6efea4ddf100d60f5fad87e6c0daae7
size 3276188594
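
Both .pth entries are Git LFS pointer files rather than the weights themselves: the repository stores only the spec version, the sha256 oid, and the byte size (roughly 3.3 GB each), and the actual checkpoints are fetched separately when the repo is pulled with LFS. A small sketch, not part of the commit, for checking that a fetched file matches the oid recorded above:

# Verify a pulled LFS object against the sha256 oid from its pointer file.
import hashlib

CHECKPOINT = 'checkpoints/nyu_L.pth'
EXPECTED_OID = '81d95d5f26f5d01b7e8b060467eef77ea6efea4ddf100d60f5fad87e6c0daae7'

sha = hashlib.sha256()
with open(CHECKPOINT, 'rb') as f:
    # Stream in 1 MiB chunks; the file is far too large to read at once.
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha.update(chunk)

if sha.hexdigest() != EXPECTED_OID:
    raise RuntimeError(f'{CHECKPOINT} is still an LFS pointer or is corrupted')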
    	
iebins/dataloaders/__init__.py ADDED
File without changes
    	
iebins/dataloaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes)
    	
iebins/dataloaders/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (9.15 kB)
    	
iebins/dataloaders/__pycache__/dataloader_sun.cpython-38.pyc ADDED
Binary file (8.93 kB)
    	
iebins/dataloaders/dataloader.py ADDED
@@ -0,0 +1,343 @@
import torch
from torch.utils.data import Dataset, DataLoader
import torch.utils.data.distributed
from torchvision import transforms

import numpy as np
from PIL import Image
import os
import random
import copy

from utils import DistributedSamplerNoEvenlyDivisible


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def preprocessing_transforms(mode):
    return transforms.Compose([
        ToTensor(mode=mode)
    ])


class NewDataLoader(object):
    def __init__(self, args, mode):
        if mode == 'train':
            self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            if args.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples, args.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=args.num_threads,
                                   pin_memory=True,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            if args.distributed:
                # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
                self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)

        else:
            print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))


class DataLoadPreprocess(Dataset):
    def __init__(self, args, mode, transform=None, is_for_online_eval=False):
        self.args = args
        if mode == 'online_eval':
            with open(args.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(args.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor
        self.is_for_online_eval = is_for_online_eval

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        # focal = float(sample_path.split()[2])
        focal = 518.8579

        if self.mode == 'train':
            if self.args.dataset == 'kitti':
                rgb_file = sample_path.split()[0]
                depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
                if self.args.use_right is True and random.random() > 0.5:
                    rgb_file = rgb_file.replace('image_02', 'image_03')
                    depth_file = depth_file.replace('image_02', 'image_03')
            else:
                rgb_file = sample_path.split()[0]
                depth_file = sample_path.split()[1]

            image_path = os.path.join(self.args.data_path, rgb_file)
            depth_path = os.path.join(self.args.gt_path, depth_file)

            image = Image.open(image_path)
            depth_gt = Image.open(depth_path)

            if self.args.do_kb_crop is True:
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # To avoid blank boundaries due to pixel registration
            if self.args.dataset == 'nyu':
                if self.args.input_height == 480:
                    depth_gt = np.array(depth_gt)
                    valid_mask = np.zeros_like(depth_gt)
                    valid_mask[45:472, 43:608] = 1
                    depth_gt[valid_mask==0] = 0
                    depth_gt = Image.fromarray(depth_gt)
                else:
                    depth_gt = depth_gt.crop((43, 45, 608, 472))
                    image = image.crop((43, 45, 608, 472))

            if self.args.do_random_rotate is True:
                random_angle = (random.random() - 0.5) * 2 * self.args.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            if self.args.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
                image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
            image, depth_gt = self.train_preprocess(image, depth_gt)
            # https://github.com/ShuweiShao/URCDC-Depth
            image, depth_gt = self.Cut_Flip(image, depth_gt)
            sample = {'image': image, 'depth': depth_gt, 'focal': focal}

        else:
            if self.mode == 'online_eval':
                data_path = self.args.data_path_eval
            else:
                data_path = self.args.data_path

            image_path = os.path.join(data_path, "./" + sample_path.split()[0])
            image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.args.gt_path_eval
                depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
                if self.args.dataset == 'kitti':
                    depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
                has_valid_depth = False
                try:
                    depth_gt = Image.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    depth_gt = False
                    # print('Missing gt for {}'.format(image_path))

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.args.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

            if self.args.do_kb_crop is True:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
            else:
                sample = {'image': image, 'focal': focal}

        if self.transform:
            sample = self.transform([sample, self.args.dataset])

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]
        return img, depth

    def train_preprocess(self, image, depth_gt):
        # Random flipping
        do_flip = random.random()
        if do_flip > 0.5:
            image = (image[:, ::-1, :]).copy()
            depth_gt = (depth_gt[:, ::-1, :]).copy()

        # Random gamma, brightness, color augmentation
        do_augment = random.random()
        if do_augment > 0.5:
            image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation
        if self.args.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def Cut_Flip(self, image, depth):

        p = random.random()
        if p < 0.5:
            return image, depth
        image_copy = copy.deepcopy(image)
        depth_copy = copy.deepcopy(depth)
        h, w, c = image.shape

        N = 2
        h_list = []
        h_interval_list = []   # hight interval
        for i in range(N-1):
            h_list.append(random.randint(int(0.2*h), int(0.8*h)))
        h_list.append(h)
        h_list.append(0)
        h_list.sort()
        h_list_inv = np.array([h]*(N+1))-np.array(h_list)
        for i in range(len(h_list)-1):
            h_interval_list.append(h_list[i+1]-h_list[i])
        for i in range(N):
            image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
            depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]

        return image, depth

    def __len__(self):
        return len(self.filenames)


class ToTensor(object):
    def __init__(self, mode):
        self.mode = mode
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, sample_dataset):

        sample = sample_dataset[0]
        dataset = sample_dataset[1]

        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)

        if dataset == 'kitti':
            K_p = np.array([[716.88, 0, 596.5593, 0],
                  [0, 716.88, 149.854, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]], dtype=np.float32)
            inv_K_p = np.linalg.pinv(K_p)
            inv_K_p = torch.from_numpy(inv_K_p)

        elif dataset == 'nyu':
            K_p = np.array([[518.8579, 0, 325.5824, 0],
                  [0, 518.8579, 253.7362, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]], dtype=np.float32)
            inv_K_p = np.linalg.pinv(K_p)
            inv_K_p = torch.from_numpy(inv_K_p)

        if self.mode == 'test':
            return {'image': image, 'inv_K_p': inv_K_p, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {'image': image, 'depth': depth, 'focal': focal}
        else:
            has_valid_depth = sample['has_valid_depth']
            return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
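
The argument object this dataloader expects is built elsewhere (the training and eval scripts use argparse), but the fields it actually reads are visible above. The sketch below is hypothetical: SimpleNamespace stands in for the parsed arguments, the paths and split file are placeholders, and it assumes the iebins directory is the working directory so that `from utils import DistributedSamplerNoEvenlyDivisible` resolves.

# Hypothetical wiring of NewDataLoader for an NYU training run; all paths are placeholders.
from types import SimpleNamespace

from dataloaders.dataloader import NewDataLoader  # run from inside iebins/

args = SimpleNamespace(
    dataset='nyu',
    data_path='datasets/nyu/sync',                      # placeholder
    gt_path='datasets/nyu/sync',                        # placeholder
    filenames_file='data_splits/nyu_train_files.txt',   # placeholder
    input_height=480, input_width=640,
    do_kb_crop=False, do_random_rotate=True, degree=2.5,
    use_right=False,
    batch_size=8, num_threads=4, distributed=False,
)

loader = NewDataLoader(args, mode='train')
batch = next(iter(loader.data))
# ToTensor returns dicts: a 'train' batch carries 'image', 'depth' and 'focal'.
print(batch['image'].shape, batch['depth'].shape)  # expected [B, 3, 480, 640] and [B, 1, 480, 640]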
    	
        iebins/dataloaders/dataloader_sun.py
    ADDED
    
    | @@ -0,0 +1,326 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch.utils.data import Dataset, DataLoader
         | 
| 3 | 
            +
            import torch.utils.data.distributed
         | 
| 4 | 
            +
            from torchvision import transforms
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            import random
         | 
| 10 | 
            +
            import copy
         | 
| 11 | 
            +
            import cv2
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from utils import DistributedSamplerNoEvenlyDivisible
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def _is_pil_image(img):
         | 
| 17 | 
            +
                return isinstance(img, Image.Image)
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def _is_numpy_image(img):
         | 
| 21 | 
            +
                return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def preprocessing_transforms(mode):
         | 
| 25 | 
            +
                return transforms.Compose([
         | 
| 26 | 
            +
                    ToTensor(mode=mode)
         | 
| 27 | 
            +
                ])
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class NewDataLoader(object):
         | 
| 31 | 
            +
                def __init__(self, args, mode):
         | 
| 32 | 
            +
                    if mode == 'train':
         | 
| 33 | 
            +
                        self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
         | 
| 34 | 
            +
                        if args.distributed:
         | 
| 35 | 
            +
                            self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
         | 
| 36 | 
            +
                        else:
         | 
| 37 | 
            +
                            self.train_sampler = None
         | 
| 38 | 
            +
                
         | 
| 39 | 
            +
                        self.data = DataLoader(self.training_samples, args.batch_size,
         | 
| 40 | 
            +
                                               shuffle=(self.train_sampler is None),
         | 
| 41 | 
            +
                                               num_workers=args.num_threads,
         | 
| 42 | 
            +
                                               pin_memory=True,
         | 
| 43 | 
            +
                                               sampler=self.train_sampler)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    elif mode == 'online_eval':
         | 
| 46 | 
            +
                        self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
         | 
| 47 | 
            +
                        if args.distributed:
         | 
| 48 | 
            +
                            # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
         | 
| 49 | 
            +
                            self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
         | 
| 50 | 
            +
                        else:
         | 
| 51 | 
            +
                            self.eval_sampler = None
         | 
| 52 | 
            +
                        self.data = DataLoader(self.testing_samples, 1,
         | 
| 53 | 
            +
                                               shuffle=False,
         | 
| 54 | 
            +
                                               num_workers=1,
         | 
| 55 | 
            +
                                               pin_memory=True,
         | 
| 56 | 
            +
                                               sampler=self.eval_sampler)
         | 
| 57 | 
            +
                    
         | 
| 58 | 
            +
                    elif mode == 'test':
         | 
| 59 | 
            +
                        self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
         | 
| 60 | 
            +
                        self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
         | 
| 64 | 
            +
                        
         | 
| 65 | 
            +
                        
         | 
| 66 | 
            +
            class DataLoadPreprocess(Dataset):
         | 
| 67 | 
            +
                def __init__(self, args, mode, transform=None, is_for_online_eval=False):
         | 
| 68 | 
            +
                    self.args = args
         | 
| 69 | 
            +
                    if mode == 'online_eval':
         | 
| 70 | 
            +
                        with open(args.filenames_file_eval, 'r') as f:
         | 
| 71 | 
            +
                            self.filenames = f.readlines()
         | 
| 72 | 
            +
                    else:
         | 
| 73 | 
            +
                        with open(args.filenames_file, 'r') as f:
         | 
| 74 | 
            +
                            self.filenames = f.readlines()
         | 
| 75 | 
            +
                
         | 
| 76 | 
            +
                    self.mode = mode
         | 
| 77 | 
            +
                    self.transform = transform
         | 
| 78 | 
            +
                    self.to_tensor = ToTensor
         | 
| 79 | 
            +
                    self.is_for_online_eval = is_for_online_eval
         | 
| 80 | 
            +
                
         | 
| 81 | 
            +
                def __getitem__(self, idx):
         | 
| 82 | 
            +
                    sample_path = self.filenames[idx]
         | 
| 83 | 
            +
                    # focal = float(sample_path.split()[2])
         | 
| 84 | 
            +
                    focal = 518.8579
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    if self.mode == 'train':
         | 
| 87 | 
            +
                        if self.args.dataset == 'kitti':
         | 
| 88 | 
            +
                            rgb_file = sample_path.split()[0]
         | 
| 89 | 
            +
                            depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
         | 
| 90 | 
            +
                            if self.args.use_right is True and random.random() > 0.5:
         | 
| 91 | 
            +
                                rgb_file = rgb_file.replace('image_02', 'image_03')
         | 
| 92 | 
            +
                                depth_file = depth_file.replace('image_02', 'image_03')
         | 
| 93 | 
            +
                        else:
         | 
| 94 | 
            +
                            rgb_file = sample_path.split()[0]
         | 
| 95 | 
            +
                            depth_file = sample_path.split()[1]
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                        image_path = os.path.join(self.args.data_path, rgb_file)
         | 
| 98 | 
            +
                        depth_path = os.path.join(self.args.gt_path, depth_file)
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                        image = Image.open(image_path)
         | 
| 101 | 
            +
                        depth_gt = Image.open(depth_path)
         | 
| 102 | 
            +
                        
         | 
| 103 | 
            +
                        if self.args.do_kb_crop is True:
         | 
| 104 | 
            +
                            height = image.height
         | 
| 105 | 
            +
                            width = image.width
         | 
| 106 | 
            +
                            top_margin = int(height - 352)
         | 
| 107 | 
            +
                            left_margin = int((width - 1216) / 2)
         | 
| 108 | 
            +
                            depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
         | 
| 109 | 
            +
                            image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
         | 
| 110 | 
            +
                        
         | 
| 111 | 
            +
                        # To avoid blank boundaries due to pixel registration
         | 
| 112 | 
            +
                        if self.args.dataset == 'nyu':
         | 
| 113 | 
            +
                            if self.args.input_height == 480:
         | 
| 114 | 
            +
                                depth_gt = np.array(depth_gt)
         | 
| 115 | 
            +
                                valid_mask = np.zeros_like(depth_gt)
         | 
| 116 | 
            +
                                valid_mask[45:472, 43:608] = 1
         | 
| 117 | 
            +
                                depth_gt[valid_mask==0] = 0
         | 
| 118 | 
            +
                                depth_gt = Image.fromarray(depth_gt)
         | 
| 119 | 
            +
                            else:
         | 
| 120 | 
            +
                                depth_gt = depth_gt.crop((43, 45, 608, 472))
         | 
| 121 | 
            +
                                image = image.crop((43, 45, 608, 472))
         | 
| 122 | 
            +
                
         | 
| 123 | 
            +
                        if self.args.do_random_rotate is True:
         | 
| 124 | 
            +
                            random_angle = (random.random() - 0.5) * 2 * self.args.degree
         | 
| 125 | 
            +
                            image = self.rotate_image(image, random_angle)
         | 
| 126 | 
            +
                            depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
         | 
| 127 | 
            +
                        
         | 
| 128 | 
            +
                        image = np.asarray(image, dtype=np.float32) / 255.0
         | 
| 129 | 
            +
                        depth_gt = np.asarray(depth_gt, dtype=np.float32)
         | 
| 130 | 
            +
                        depth_gt = np.expand_dims(depth_gt, axis=2)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
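                        # Depth pngs are stored as integers: NYU uses millimetres (divide by
                        # 1000 to get metres), KITTI uses uint16 values scaled by 256.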
                        if self.args.dataset == 'nyu':
         | 
| 133 | 
            +
                            depth_gt = depth_gt / 1000.0
         | 
| 134 | 
            +
                        else:
         | 
| 135 | 
            +
                            depth_gt = depth_gt / 256.0
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                        if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
         | 
| 138 | 
            +
                            image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
         | 
| 139 | 
            +
                        image, depth_gt = self.train_preprocess(image, depth_gt)
         | 
| 140 | 
            +
                        image, depth_gt = self.Cut_Flip(image, depth_gt)
         | 
| 141 | 
            +
                        sample = {'image': image, 'depth': depth_gt, 'focal': focal}
         | 
| 142 | 
            +
                    
         | 
| 143 | 
            +
                    else:
         | 
| 144 | 
            +
                        if self.mode == 'online_eval':
         | 
| 145 | 
            +
                            data_path = self.args.data_path_eval
         | 
| 146 | 
            +
                        else:
         | 
| 147 | 
            +
                            data_path = self.args.data_path
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        image_path = os.path.join(data_path, "./" + sample_path.split()[0])
         | 
| 150 | 
            +
                        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
         | 
| 151 | 
            +
                        image = cv2.resize(image, (640, 480))
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                        if self.mode == 'online_eval':
         | 
| 154 | 
            +
                            gt_path = self.args.gt_path_eval
         | 
| 155 | 
            +
                            depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
         | 
| 156 | 
            +
                            if self.args.dataset == 'kitti':
         | 
| 157 | 
            +
                                depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
         | 
| 158 | 
            +
                            has_valid_depth = False
         | 
| 159 | 
            +
                            try:
         | 
| 160 | 
            +
                                depth_gt = Image.open(depth_path)
         | 
| 161 | 
            +
                                has_valid_depth = True
         | 
| 162 | 
            +
                            except IOError:
         | 
| 163 | 
            +
                                depth_gt = False
         | 
| 164 | 
            +
                                # print('Missing gt for {}'.format(image_path))
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                            if has_valid_depth:
         | 
| 167 | 
            +
                            depth_gt = np.asarray(depth_gt, dtype=np.uint16)  # read the 16-bit depth png
         | 
| 168 | 
            +
                            depth_gt = np.bitwise_or(np.right_shift(depth_gt, 3), np.left_shift(depth_gt, 16 - 3))  # circular bit-shift by 3 (raw NYU/SUN RGB-D depth encoding)
         | 
| 169 | 
            +
                                depth_gt = np.expand_dims(depth_gt, axis=2)
         | 
| 170 | 
            +
                                if self.args.dataset == 'nyu':
         | 
| 171 | 
            +
                                depth_gt = depth_gt.astype(np.single) / 1000  # millimetres to metres
         | 
| 172 | 
            +
                                depth_gt = depth_gt.astype(np.float32)  # ensure float32
         | 
| 173 | 
            +
                                else:
         | 
| 174 | 
            +
                                    depth_gt = depth_gt / 256.0
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                        if self.args.do_kb_crop is True:
         | 
| 177 | 
            +
                            height = image.shape[0]
         | 
| 178 | 
            +
                            width = image.shape[1]
         | 
| 179 | 
            +
                            top_margin = int(height - 352)
         | 
| 180 | 
            +
                            left_margin = int((width - 1216) / 2)
         | 
| 181 | 
            +
                            image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
         | 
| 182 | 
            +
                            if self.mode == 'online_eval' and has_valid_depth:
         | 
| 183 | 
            +
                                depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
         | 
| 184 | 
            +
                        
         | 
| 185 | 
            +
                        if self.mode == 'online_eval':
         | 
| 186 | 
            +
                            sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
         | 
| 187 | 
            +
                        else:
         | 
| 188 | 
            +
                            sample = {'image': image, 'focal': focal}
         | 
| 189 | 
            +
                    
         | 
| 190 | 
            +
                    if self.transform:
         | 
| 191 | 
            +
                        sample = self.transform(sample)
         | 
| 192 | 
            +
                    
         | 
| 193 | 
            +
                    return sample
         | 
| 194 | 
            +
                
         | 
| 195 | 
            +
                def rotate_image(self, image, angle, flag=Image.BILINEAR):
         | 
| 196 | 
            +
                    result = image.rotate(angle, resample=flag)
         | 
| 197 | 
            +
                    return result
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def random_crop(self, img, depth, height, width):
         | 
| 200 | 
            +
                    assert img.shape[0] >= height
         | 
| 201 | 
            +
                    assert img.shape[1] >= width
         | 
| 202 | 
            +
                    assert img.shape[0] == depth.shape[0]
         | 
| 203 | 
            +
                    assert img.shape[1] == depth.shape[1]
         | 
| 204 | 
            +
                    x = random.randint(0, img.shape[1] - width)
         | 
| 205 | 
            +
                    y = random.randint(0, img.shape[0] - height)
         | 
| 206 | 
            +
                    img = img[y:y + height, x:x + width, :]
         | 
| 207 | 
            +
                    depth = depth[y:y + height, x:x + width, :]
         | 
| 208 | 
            +
                    return img, depth
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                def train_preprocess(self, image, depth_gt):
         | 
| 211 | 
            +
                    # Random flipping
         | 
| 212 | 
            +
                    do_flip = random.random()
         | 
| 213 | 
            +
                    if do_flip > 0.5:
         | 
| 214 | 
            +
                        image = (image[:, ::-1, :]).copy()
         | 
| 215 | 
            +
                        depth_gt = (depth_gt[:, ::-1, :]).copy()
         | 
| 216 | 
            +
                
         | 
| 217 | 
            +
                    # Random gamma, brightness, color augmentation
         | 
| 218 | 
            +
                    do_augment = random.random()
         | 
| 219 | 
            +
                    if do_augment > 0.5:
         | 
| 220 | 
            +
                        image = self.augment_image(image)
         | 
| 221 | 
            +
                
         | 
| 222 | 
            +
                    return image, depth_gt
         | 
| 223 | 
            +
                
         | 
| 224 | 
            +
                def augment_image(self, image):
         | 
| 225 | 
            +
                    # gamma augmentation
         | 
| 226 | 
            +
                    gamma = random.uniform(0.9, 1.1)
         | 
| 227 | 
            +
                    image_aug = image ** gamma
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # brightness augmentation
         | 
| 230 | 
            +
                    if self.args.dataset == 'nyu':
         | 
| 231 | 
            +
                        brightness = random.uniform(0.75, 1.25)
         | 
| 232 | 
            +
                    else:
         | 
| 233 | 
            +
                        brightness = random.uniform(0.9, 1.1)
         | 
| 234 | 
            +
                    image_aug = image_aug * brightness
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # color augmentation
         | 
| 237 | 
            +
                    colors = np.random.uniform(0.9, 1.1, size=3)
         | 
| 238 | 
            +
                    white = np.ones((image.shape[0], image.shape[1]))
         | 
| 239 | 
            +
                    color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
         | 
| 240 | 
            +
                    image_aug *= color_image
         | 
| 241 | 
            +
                    image_aug = np.clip(image_aug, 0, 1)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    return image_aug
         | 
| 244 | 
            +
                
         | 
| 245 | 
            +
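                # CutFlip-style augmentation: with probability 0.5 the image is cut at a
                # random height (between 0.2*h and 0.8*h) into two horizontal strips whose
                # vertical order is swapped; the depth map is shuffled identically so that
                # image/depth alignment is preserved.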
                def Cut_Flip(self, image, depth):
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    p = random.random()
         | 
| 248 | 
            +
                    if p < 0.5:
         | 
| 249 | 
            +
                        return image, depth
         | 
| 250 | 
            +
                    image_copy = copy.deepcopy(image)
         | 
| 251 | 
            +
                    depth_copy = copy.deepcopy(depth)
         | 
| 252 | 
            +
                    h, w, c = image.shape
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    N = 2  # number of horizontal strips to swap
         | 
| 255 | 
            +
                    h_list = []
         | 
| 256 | 
            +
                    h_interval_list = []   # height intervals between consecutive cut points
         | 
| 257 | 
            +
                    for i in range(N-1):
         | 
| 258 | 
            +
                        h_list.append(random.randint(int(0.2*h), int(0.8*h)))
         | 
| 259 | 
            +
                    h_list.append(h)
         | 
| 260 | 
            +
                    h_list.append(0)  
         | 
| 261 | 
            +
                    h_list.sort()
         | 
| 262 | 
            +
                    h_list_inv = np.array([h]*(N+1))-np.array(h_list)
         | 
| 263 | 
            +
                    for i in range(len(h_list)-1):
         | 
| 264 | 
            +
                        h_interval_list.append(h_list[i+1]-h_list[i])
         | 
| 265 | 
            +
                    for i in range(N):
         | 
| 266 | 
            +
                        image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
         | 
| 267 | 
            +
                        depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    return image, depth
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                
         | 
| 272 | 
            +
                def __len__(self):
         | 
| 273 | 
            +
                    return len(self.filenames)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
             | 
| 276 | 
            +
            class ToTensor(object):
         | 
| 277 | 
            +
                def __init__(self, mode):
         | 
| 278 | 
            +
                    self.mode = mode
         | 
| 279 | 
            +
                    self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         | 
| 280 | 
            +
                
         | 
| 281 | 
            +
                def __call__(self, sample):
         | 
| 282 | 
            +
                    image, focal = sample['image'], sample['focal']
         | 
| 283 | 
            +
                    image = self.to_tensor(image)
         | 
| 284 | 
            +
                    image = self.normalize(image)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    if self.mode == 'test':
         | 
| 287 | 
            +
                        return {'image': image, 'focal': focal}
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    depth = sample['depth']
         | 
| 290 | 
            +
                    if self.mode == 'train':
         | 
| 291 | 
            +
                        depth = self.to_tensor(depth)
         | 
| 292 | 
            +
                        return {'image': image, 'depth': depth, 'focal': focal}
         | 
| 293 | 
            +
                    else:
         | 
| 294 | 
            +
                        has_valid_depth = sample['has_valid_depth']
         | 
| 295 | 
            +
                        return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
         | 
| 296 | 
            +
                
         | 
| 297 | 
            +
                def to_tensor(self, pic):
         | 
| 298 | 
            +
                    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
         | 
| 299 | 
            +
                        raise TypeError(
         | 
| 300 | 
            +
                            'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
         | 
| 301 | 
            +
                    
         | 
| 302 | 
            +
                    if isinstance(pic, np.ndarray):
         | 
| 303 | 
            +
                        img = torch.from_numpy(pic.transpose((2, 0, 1)))
         | 
| 304 | 
            +
                        return img
         | 
| 305 | 
            +
                    
         | 
| 306 | 
            +
                    # handle PIL Image
         | 
| 307 | 
            +
                    if pic.mode == 'I':
         | 
| 308 | 
            +
                        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
         | 
| 309 | 
            +
                    elif pic.mode == 'I;16':
         | 
| 310 | 
            +
                        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
         | 
| 311 | 
            +
                    else:
         | 
| 312 | 
            +
                        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
         | 
| 313 | 
            +
                    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
         | 
| 314 | 
            +
                    if pic.mode == 'YCbCr':
         | 
| 315 | 
            +
                        nchannel = 3
         | 
| 316 | 
            +
                    elif pic.mode == 'I;16':
         | 
| 317 | 
            +
                        nchannel = 1
         | 
| 318 | 
            +
                    else:
         | 
| 319 | 
            +
                        nchannel = len(pic.mode)
         | 
| 320 | 
            +
                    img = img.view(pic.size[1], pic.size[0], nchannel)
         | 
| 321 | 
            +
                    
         | 
| 322 | 
            +
                    img = img.transpose(0, 1).transpose(0, 2).contiguous()
         | 
| 323 | 
            +
                    if isinstance(img, torch.ByteTensor):
         | 
| 324 | 
            +
                        return img.float()
         | 
| 325 | 
            +
                    else:
         | 
| 326 | 
            +
                        return img
         | 
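A quick sanity check for the ToTensor transform above. This is a sketch only: it assumes the repository root is on PYTHONPATH so that iebins is importable as a package, and the array shapes and focal value are illustrative placeholders rather than values taken from the repository's filename lists.

import numpy as np
from iebins.dataloaders.dataloader import ToTensor  # assumes the package layout added in this commit

to_tensor = ToTensor(mode='train')
sample = {
    'image': np.random.rand(480, 640, 3).astype(np.float32),  # H x W x 3, values in [0, 1]
    'depth': np.random.rand(480, 640, 1).astype(np.float32),  # H x W x 1, metres
    'focal': 518.8579,                                         # placeholder focal length
}
out = to_tensor(sample)
print(out['image'].shape)  # torch.Size([3, 480, 640]) after transpose + ImageNet normalisation
print(out['depth'].shape)  # torch.Size([1, 480, 640])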
    	
        iebins/eval.py
    ADDED
    
    | @@ -0,0 +1,177 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import os, sys
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from tqdm import tqdm
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from utils import post_process_depth, flip_lr, compute_errors
         | 
| 10 | 
            +
            from networks.NewCRFDepth import NewCRFDepth
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def convert_arg_line_to_args(arg_line):
         | 
| 14 | 
            +
                for arg in arg_line.split():
         | 
| 15 | 
            +
                    if not arg.strip():
         | 
| 16 | 
            +
                        continue
         | 
| 17 | 
            +
                    yield arg
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
         | 
| 21 | 
            +
            parser.convert_arg_line_to_args = convert_arg_line_to_args
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            parser.add_argument('--model_name',                type=str,   help='model name', default='iebins')
         | 
| 24 | 
            +
            parser.add_argument('--encoder',                   type=str,   help='type of encoder, base07, large07, tiny07', default='large07')
         | 
| 25 | 
            +
            parser.add_argument('--checkpoint_path',           type=str,   help='path to a checkpoint to load', default='')
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Dataset
         | 
| 28 | 
            +
            parser.add_argument('--dataset',                   type=str,   help='dataset to train on, kitti or nyu', default='nyu')
         | 
| 29 | 
            +
            parser.add_argument('--input_height',              type=int,   help='input height', default=480)
         | 
| 30 | 
            +
            parser.add_argument('--input_width',               type=int,   help='input width',  default=640)
         | 
| 31 | 
            +
            parser.add_argument('--max_depth',                 type=float, help='maximum depth in estimation', default=10)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # Preprocessing
         | 
| 34 | 
            +
            parser.add_argument('--do_random_rotate',                      help='if set, will perform random rotation for augmentation', action='store_true')
         | 
| 35 | 
            +
            parser.add_argument('--degree',                    type=float, help='random rotation maximum degree', default=2.5)
         | 
| 36 | 
            +
            parser.add_argument('--do_kb_crop',                            help='if set, crop input images as kitti benchmark images', action='store_true')
         | 
| 37 | 
            +
            parser.add_argument('--use_right',                             help='if set, will randomly use right images when train on KITTI', action='store_true')
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            # Eval
         | 
| 40 | 
            +
            parser.add_argument('--data_path_eval',            type=str,   help='path to the data for evaluation', required=False)
         | 
| 41 | 
            +
            parser.add_argument('--gt_path_eval',              type=str,   help='path to the groundtruth data for evaluation', required=False)
         | 
| 42 | 
            +
            parser.add_argument('--filenames_file_eval',       type=str,   help='path to the filenames text file for evaluation', required=False)
         | 
| 43 | 
            +
            parser.add_argument('--min_depth_eval',            type=float, help='minimum depth for evaluation', default=1e-3)
         | 
| 44 | 
            +
            parser.add_argument('--max_depth_eval',            type=float, help='maximum depth for evaluation', default=80)
         | 
| 45 | 
            +
            parser.add_argument('--eigen_crop',                            help='if set, crops according to Eigen NIPS14', action='store_true')
         | 
| 46 | 
            +
            parser.add_argument('--garg_crop',                             help='if set, crops according to Garg  ECCV16', action='store_true')
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            if sys.argv.__len__() == 2:
         | 
| 50 | 
            +
                arg_filename_with_prefix = '@' + sys.argv[1]
         | 
| 51 | 
            +
                args = parser.parse_args([arg_filename_with_prefix])
         | 
| 52 | 
            +
            else:
         | 
| 53 | 
            +
                args = parser.parse_args()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            if args.dataset == 'kitti' or args.dataset == 'nyu':
         | 
| 56 | 
            +
                from dataloaders.dataloader import NewDataLoader
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def eval(model, dataloader_eval, post_process=False):
         | 
| 60 | 
            +
                eval_measures = torch.zeros(10).cuda()
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
         | 
| 63 | 
            +
                    with torch.no_grad():
         | 
| 64 | 
            +
                        image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
         | 
| 65 | 
            +
                        gt_depth = eval_sample_batched['depth']
         | 
| 66 | 
            +
                        has_valid_depth = eval_sample_batched['has_valid_depth']
         | 
| 67 | 
            +
                        if not has_valid_depth:
         | 
| 68 | 
            +
                            # print('Invalid depth. continue.')
         | 
| 69 | 
            +
                            continue
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                        pred_depths_r_list, _, _ = model(image)
         | 
| 72 | 
            +
                        if post_process:
         | 
| 73 | 
            +
                            image_flipped = flip_lr(image)
         | 
| 74 | 
            +
                            pred_depths_r_list_flipped, _, _ = model(image_flipped)
         | 
| 75 | 
            +
                        pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
                    else:
                        # no flip post-processing: use the final refined prediction directly
                        pred_depth = pred_depths_r_list[-1]
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        pred_depth = pred_depth.cpu().numpy().squeeze()
         | 
| 78 | 
            +
                        gt_depth = gt_depth.cpu().numpy().squeeze()     
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    if args.do_kb_crop:
         | 
| 81 | 
            +
                        height, width = gt_depth.shape
         | 
| 82 | 
            +
                        top_margin = int(height - 352)
         | 
| 83 | 
            +
                        left_margin = int((width - 1216) / 2)
         | 
| 84 | 
            +
                        pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
         | 
| 85 | 
            +
                        pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
         | 
| 86 | 
            +
                        pred_depth = pred_depth_uncropped
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
         | 
| 89 | 
            +
                    pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
         | 
| 90 | 
            +
                    pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
         | 
| 91 | 
            +
                    pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    if args.garg_crop or args.eigen_crop:
         | 
| 96 | 
            +
                        gt_height, gt_width = gt_depth.shape
         | 
| 97 | 
            +
                        eval_mask = np.zeros(valid_mask.shape)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                        if args.garg_crop:
         | 
| 100 | 
            +
                            eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                        elif args.eigen_crop:
         | 
| 103 | 
            +
                            if args.dataset == 'kitti':
         | 
| 104 | 
            +
                                eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 105 | 
            +
                            elif args.dataset == 'nyu':
         | 
| 106 | 
            +
                                eval_mask[45:471, 41:601] = 1
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                        valid_mask = np.logical_and(valid_mask, eval_mask)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    eval_measures[:9] += torch.tensor(measures).cuda()
         | 
| 113 | 
            +
                    eval_measures[9] += 1
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                eval_measures_cpu = eval_measures.cpu()
         | 
| 116 | 
            +
                cnt = eval_measures_cpu[9].item()
         | 
| 117 | 
            +
                eval_measures_cpu /= cnt
         | 
| 118 | 
            +
                print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
         | 
| 119 | 
            +
                print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
         | 
| 120 | 
            +
                                                                                                'sq_rel', 'log_rms', 'd1', 'd2',
         | 
| 121 | 
            +
                                                                                                'd3'))
         | 
| 122 | 
            +
                for i in range(8):
         | 
| 123 | 
            +
                    print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
         | 
| 124 | 
            +
                print('{:7.4f}'.format(eval_measures_cpu[8]))
         | 
| 125 | 
            +
                return eval_measures_cpu
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def main_worker(args):
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                # CRF model
         | 
| 131 | 
            +
                model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
         | 
| 132 | 
            +
                model.train()
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                num_params = sum([np.prod(p.size()) for p in model.parameters()])
         | 
| 135 | 
            +
                print("== Total number of parameters: {}".format(num_params))
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
         | 
| 138 | 
            +
                print("== Total number of learning parameters: {}".format(num_params_update))
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                model = torch.nn.DataParallel(model)
         | 
| 141 | 
            +
                model.cuda()
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                print("== Model Initialized")
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                if args.checkpoint_path != '':
         | 
| 146 | 
            +
                    if os.path.isfile(args.checkpoint_path):
         | 
| 147 | 
            +
                        print("== Loading checkpoint '{}'".format(args.checkpoint_path))
         | 
| 148 | 
            +
                        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
         | 
| 149 | 
            +
                        model.load_state_dict(checkpoint['model'])
         | 
| 150 | 
            +
                        print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
         | 
| 151 | 
            +
                        del checkpoint
         | 
| 152 | 
            +
                    else:
         | 
| 153 | 
            +
                        print("== No checkpoint found at '{}'".format(args.checkpoint_path))
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                cudnn.benchmark = True
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                dataloader_eval = NewDataLoader(args, 'online_eval')
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # ===== Evaluation =====
         | 
| 160 | 
            +
                model.eval()
         | 
| 161 | 
            +
                with torch.no_grad():
         | 
| 162 | 
            +
                    eval_measures = eval(model, dataloader_eval, post_process=True)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            def main():
         | 
| 166 | 
            +
                torch.cuda.empty_cache()
         | 
| 167 | 
            +
                args.distributed = False
         | 
| 168 | 
            +
                ngpus_per_node = torch.cuda.device_count()
         | 
| 169 | 
            +
                if ngpus_per_node > 1:
         | 
| 170 | 
            +
                    print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
         | 
| 171 | 
            +
                    return -1
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                main_worker(args)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
            if __name__ == '__main__':
         | 
| 177 | 
            +
                main()
         | 
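eval.py accumulates the nine error measures returned by compute_errors, which is imported from utils and not shown in this commit view. For reference, the sketch below reproduces the standard monocular-depth metric definitions in the same order the script prints them; the project's own compute_errors in iebins/utils.py remains the authoritative implementation.

import numpy as np

def compute_errors_sketch(gt, pred):
    # threshold accuracies d1/d2/d3
    thresh = np.maximum(gt / pred, pred / gt)
    d1 = (thresh < 1.25).mean()
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()

    # rms and log-rms
    rms = np.sqrt(((gt - pred) ** 2).mean())
    log_rms = np.sqrt(((np.log(gt) - np.log(pred)) ** 2).mean())

    # relative errors
    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    # scale-invariant log error and mean log10 error
    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
    log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred)))

    # order matches eval.py's printed header: silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3
    return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]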
    	
        iebins/eval_sun.py
    ADDED
    
    | @@ -0,0 +1,179 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            import os, sys
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from tqdm import tqdm
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from utils import post_process_depth, flip_lr, compute_errors
         | 
| 10 | 
            +
            from networks.NewCRFDepth import NewCRFDepth
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def convert_arg_line_to_args(arg_line):
         | 
| 14 | 
            +
                for arg in arg_line.split():
         | 
| 15 | 
            +
                    if not arg.strip():
         | 
| 16 | 
            +
                        continue
         | 
| 17 | 
            +
                    yield arg
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
         | 
| 21 | 
            +
            parser.convert_arg_line_to_args = convert_arg_line_to_args
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            parser.add_argument('--model_name',                type=str,   help='model name', default='iebins')
         | 
| 24 | 
            +
            parser.add_argument('--encoder',                   type=str,   help='type of encoder, base07, large07, tiny07', default='large07')
         | 
| 25 | 
            +
            parser.add_argument('--checkpoint_path',           type=str,   help='path to a checkpoint to load', default='')
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Dataset
         | 
| 28 | 
            +
            parser.add_argument('--dataset',                   type=str,   help='dataset to train on, kitti or nyu', default='nyu')
         | 
| 29 | 
            +
            parser.add_argument('--input_height',              type=int,   help='input height', default=480)
         | 
| 30 | 
            +
            parser.add_argument('--input_width',               type=int,   help='input width',  default=640)
         | 
| 31 | 
            +
            parser.add_argument('--max_depth',                 type=float, help='maximum depth in estimation', default=10)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # Preprocessing
         | 
| 34 | 
            +
            parser.add_argument('--do_random_rotate',                      help='if set, will perform random rotation for augmentation', action='store_true')
         | 
| 35 | 
            +
            parser.add_argument('--degree',                    type=float, help='random rotation maximum degree', default=2.5)
         | 
| 36 | 
            +
            parser.add_argument('--do_kb_crop',                            help='if set, crop input images as kitti benchmark images', action='store_true')
         | 
| 37 | 
            +
            parser.add_argument('--use_right',                             help='if set, will randomly use right images when train on KITTI', action='store_true')
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            # Eval
         | 
| 40 | 
            +
            parser.add_argument('--data_path_eval',            type=str,   help='path to the data for evaluation', required=False)
         | 
| 41 | 
            +
            parser.add_argument('--gt_path_eval',              type=str,   help='path to the groundtruth data for evaluation', required=False)
         | 
| 42 | 
            +
            parser.add_argument('--filenames_file_eval',       type=str,   help='path to the filenames text file for evaluation', required=False)
         | 
| 43 | 
            +
            parser.add_argument('--min_depth_eval',            type=float, help='minimum depth for evaluation', default=1e-3)
         | 
| 44 | 
            +
            parser.add_argument('--max_depth_eval',            type=float, help='maximum depth for evaluation', default=80)
         | 
| 45 | 
            +
            parser.add_argument('--eigen_crop',                            help='if set, crops according to Eigen NIPS14', action='store_true')
         | 
| 46 | 
            +
            parser.add_argument('--garg_crop',                             help='if set, crops according to Garg  ECCV16', action='store_true')
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            if sys.argv.__len__() == 2:
         | 
| 50 | 
            +
                arg_filename_with_prefix = '@' + sys.argv[1]
         | 
| 51 | 
            +
                args = parser.parse_args([arg_filename_with_prefix])
         | 
| 52 | 
            +
            else:
         | 
| 53 | 
            +
                args = parser.parse_args()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            if args.dataset == 'nyu':
         | 
| 56 | 
            +
                from dataloaders.dataloader_sun import NewDataLoader
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def eval(model, dataloader_eval, post_process=False):
         | 
| 60 | 
            +
                eval_measures = torch.zeros(10).cuda()
         | 
| 61 | 
            +
                for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
         | 
| 62 | 
            +
                    with torch.no_grad():
         | 
| 63 | 
            +
                        image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
         | 
| 64 | 
            +
                        gt_depth = eval_sample_batched['depth']
         | 
| 65 | 
            +
                        has_valid_depth = eval_sample_batched['has_valid_depth']
         | 
| 66 | 
            +
                        if not has_valid_depth:
         | 
| 67 | 
            +
                            # print('Invalid depth. continue.')
         | 
| 68 | 
            +
                            continue
         | 
| 69 | 
            +
                        _, hh, ww, _ = gt_depth.shape
         | 
| 70 | 
            +
                        pred_depths_r_list, _, _ = model(image)
         | 
| 71 | 
            +
                        if post_process:
         | 
| 72 | 
            +
                            image_flipped = flip_lr(image)
         | 
| 73 | 
            +
                            pred_depths_r_list_flipped, _, _ = model(image_flipped)
         | 
| 74 | 
            +
                            pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
         | 
| 75 | 
            +
                        pred_depth = F.interpolate(pred_depth, [hh, ww], mode="bilinear", align_corners=False)
                    else:
                        # no flip post-processing: resize the final refined prediction directly
                        pred_depth = F.interpolate(pred_depths_r_list[-1], [hh, ww], mode="bilinear", align_corners=False)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        pred_depth = pred_depth.cpu().numpy().squeeze()
         | 
| 78 | 
            +
                        gt_depth = gt_depth.cpu().numpy().squeeze()
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    if args.do_kb_crop:
         | 
| 81 | 
            +
                        height, width = gt_depth.shape
         | 
| 82 | 
            +
                        top_margin = int(height - 352)
         | 
| 83 | 
            +
                        left_margin = int((width - 1216) / 2)
         | 
| 84 | 
            +
                        pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
         | 
| 85 | 
            +
                        pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
         | 
| 86 | 
            +
                        pred_depth = pred_depth_uncropped
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
         | 
| 89 | 
            +
                    pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
         | 
| 90 | 
            +
                    pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
         | 
| 91 | 
            +
                    pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
         | 
| 92 | 
            +
                    pred_depth[pred_depth > 8] = 8
         | 
| 93 | 
            +
                    gt_depth[gt_depth > 8] = 8
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    if args.garg_crop or args.eigen_crop:
         | 
| 98 | 
            +
                        gt_height, gt_width = gt_depth.shape
         | 
| 99 | 
            +
                        eval_mask = np.zeros(valid_mask.shape)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                        if args.garg_crop:
         | 
| 102 | 
            +
                            eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                        elif args.eigen_crop:
         | 
| 105 | 
            +
                            if args.dataset == 'kitti':
         | 
| 106 | 
            +
                                eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 107 | 
            +
                            elif args.dataset == 'nyu':
         | 
| 108 | 
            +
                                eval_mask[45:471, 41:601] = 1
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        valid_mask = np.logical_and(valid_mask, eval_mask)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    eval_measures[:9] += torch.tensor(measures).cuda()
         | 
| 115 | 
            +
                    eval_measures[9] += 1
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                eval_measures_cpu = eval_measures.cpu()
         | 
| 118 | 
            +
                cnt = eval_measures_cpu[9].item()
         | 
| 119 | 
            +
                eval_measures_cpu /= cnt
         | 
| 120 | 
            +
                print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
         | 
| 121 | 
            +
                print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
         | 
| 122 | 
            +
                                                                                                'sq_rel', 'log_rms', 'd1', 'd2',
         | 
| 123 | 
            +
                                                                                                'd3'))
         | 
| 124 | 
            +
                for i in range(8):
         | 
| 125 | 
            +
                    print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
         | 
| 126 | 
            +
                print('{:7.4f}'.format(eval_measures_cpu[8]))
         | 
| 127 | 
            +
                return eval_measures_cpu
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            def main_worker(args):
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                # CRF model
         | 
| 133 | 
            +
                model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
         | 
| 134 | 
            +
                model.train()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                num_params = sum([np.prod(p.size()) for p in model.parameters()])
         | 
| 137 | 
            +
                print("== Total number of parameters: {}".format(num_params))
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
         | 
| 140 | 
            +
                print("== Total number of learning parameters: {}".format(num_params_update))
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                model = torch.nn.DataParallel(model)
         | 
| 143 | 
            +
                model.cuda()
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                print("== Model Initialized")
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                if args.checkpoint_path != '':
         | 
| 148 | 
            +
                    if os.path.isfile(args.checkpoint_path):
         | 
| 149 | 
            +
                        print("== Loading checkpoint '{}'".format(args.checkpoint_path))
         | 
| 150 | 
            +
                        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
         | 
| 151 | 
            +
                        model.load_state_dict(checkpoint['model'])
         | 
| 152 | 
            +
                        print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
         | 
| 153 | 
            +
                        del checkpoint
         | 
| 154 | 
            +
                    else:
         | 
| 155 | 
            +
                        print("== No checkpoint found at '{}'".format(args.checkpoint_path))
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                cudnn.benchmark = True
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                dataloader_eval = NewDataLoader(args, 'online_eval')
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                # ===== Evaluation =====
         | 
| 162 | 
            +
                model.eval()
         | 
| 163 | 
            +
                with torch.no_grad():
         | 
| 164 | 
            +
                    eval_measures = eval(model, dataloader_eval, post_process=True)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def main():
         | 
| 168 | 
            +
                torch.cuda.empty_cache()
         | 
| 169 | 
            +
                args.distributed = False
         | 
| 170 | 
            +
                ngpus_per_node = torch.cuda.device_count()
         | 
| 171 | 
            +
                if ngpus_per_node > 1:
         | 
| 172 | 
            +
                    print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
         | 
| 173 | 
            +
                    return -1
         | 
| 174 | 
            +
                
         | 
| 175 | 
            +
                main_worker(args)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
            if __name__ == '__main__':
         | 
| 179 | 
            +
                main()
         | 
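Both evaluation scripts and the single-image inference script rely on flip_lr and post_process_depth from utils, which are not shown in this commit view. The sketch below only illustrates the core idea of flip-averaged inference; the real post_process_depth in iebins/utils.py may additionally blend the left/right image borders.

import torch

def flip_lr_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: [B, C, H, W]; mirror along the width axis
    return torch.flip(x, dims=[3])

def post_process_depth_sketch(depth: torch.Tensor, depth_flipped: torch.Tensor) -> torch.Tensor:
    # average the direct prediction with the re-mirrored prediction of the flipped input
    return 0.5 * (depth + flip_lr_sketch(depth_flipped))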
    	
        iebins/inference_single_image.py
    ADDED
    
    | @@ -0,0 +1,117 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import os, sys
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from tqdm import tqdm
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from utils import post_process_depth, flip_lr, compute_errors
         | 
| 10 | 
            +
            from networks.NewCRFDepth import NewCRFDepth
         | 
| 11 | 
            +
            from PIL import Image 
         | 
| 12 | 
            +
            from torchvision import transforms
         | 
| 13 | 
            +
            import matplotlib.pyplot as plt
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def convert_arg_line_to_args(arg_line):
         | 
| 17 | 
            +
                for arg in arg_line.split():
         | 
| 18 | 
            +
                    if not arg.strip():
         | 
| 19 | 
            +
                        continue
         | 
| 20 | 
            +
                    yield arg
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
         | 
| 24 | 
            +
            parser.convert_arg_line_to_args = convert_arg_line_to_args
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            parser.add_argument('--model_name',                type=str,   help='model name', default='iebins')
         | 
| 27 | 
            +
            parser.add_argument('--encoder',                   type=str,   help='type of encoder, base07, large07', default='large07')
         | 
| 28 | 
            +
            parser.add_argument('--checkpoint_path',           type=str,   help='path to a checkpoint to load', default='')
         | 
| 29 | 
            +
            parser.add_argument('--dataset',                   type=str,   help='dataset to train on, kitti or nyu', default='nyu')
         | 
| 30 | 
            +
            parser.add_argument('--image_path',                type=str,   help='path to the image for inference', required=False)
         | 
| 31 | 
            +
            parser.add_argument('--max_depth',                 type=float, help='maximum depth in estimation', default=10)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            if sys.argv.__len__() == 2:
         | 
| 35 | 
            +
                arg_filename_with_prefix = '@' + sys.argv[1]
         | 
| 36 | 
            +
                args = parser.parse_args([arg_filename_with_prefix])
         | 
| 37 | 
            +
            else:
         | 
| 38 | 
            +
                args = parser.parse_args()
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def inference(model, post_process=False):
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                image = np.asarray(Image.open(args.image_path), dtype=np.float32) / 255.0
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if args.dataset == 'kitti':
         | 
| 46 | 
            +
                    height = image.shape[0]
         | 
| 47 | 
            +
                    width = image.shape[1]
         | 
| 48 | 
            +
                    top_margin = int(height - 352)
         | 
| 49 | 
            +
                    left_margin = int((width - 1216) / 2)
         | 
| 50 | 
            +
                    image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                image = torch.from_numpy(image.transpose((2, 0, 1)))
         | 
| 53 | 
            +
                image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                with torch.no_grad():
         | 
| 56 | 
            +
        image = image.unsqueeze(0).cuda()  # torch.autograd.Variable is deprecated; tensors work directly
         | 
| 57 | 
            +
                   
         | 
| 58 | 
            +
                    pred_depths_r_list, _, _ = model(image)
         | 
| 59 | 
            +
                    if post_process:
         | 
| 60 | 
            +
                        image_flipped = flip_lr(image)
         | 
| 61 | 
            +
                        pred_depths_r_list_flipped, _, _ = model(image_flipped)
         | 
| 62 | 
            +
                        pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
         | 
| 63 | 
            +
             | 
| 64 | 
            +
        pred_depth = (pred_depth if post_process else pred_depths_r_list[-1]).cpu().numpy().squeeze()
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    if args.dataset == 'kitti':
         | 
| 67 | 
            +
                        plt.imsave('depth.png', np.log10(pred_depth), cmap='magma')
         | 
| 68 | 
            +
                    else:
         | 
| 69 | 
            +
                        plt.imsave('depth.png', pred_depth, cmap='jet')
         | 
| 70 | 
            +
                        
         | 
| 71 | 
            +
                      
         | 
| 72 | 
            +
            def main_worker(args):
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
         | 
| 75 | 
            +
                model.train()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                num_params = sum([np.prod(p.size()) for p in model.parameters()])
         | 
| 78 | 
            +
                print("== Total number of parameters: {}".format(num_params))
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
         | 
| 81 | 
            +
                print("== Total number of learning parameters: {}".format(num_params_update))
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                model = torch.nn.DataParallel(model)
         | 
| 84 | 
            +
                model.cuda()
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                print("== Model Initialized")
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                if args.checkpoint_path != '':
         | 
| 89 | 
            +
                    if os.path.isfile(args.checkpoint_path):
         | 
| 90 | 
            +
                        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
         | 
| 91 | 
            +
                        model.load_state_dict(checkpoint['model'])
         | 
| 92 | 
            +
                        print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
         | 
| 93 | 
            +
                        del checkpoint
         | 
| 94 | 
            +
                    else:
         | 
| 95 | 
            +
                        print("== No checkpoint found at '{}'".format(args.checkpoint_path))
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                cudnn.benchmark = True
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # ===== Inference ======
         | 
| 100 | 
            +
                model.eval()
         | 
| 101 | 
            +
                with torch.no_grad():
         | 
| 102 | 
            +
                    inference(model, post_process=True)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            def main():
         | 
| 106 | 
            +
                torch.cuda.empty_cache()
         | 
| 107 | 
            +
                args.distributed = False
         | 
| 108 | 
            +
                ngpus_per_node = torch.cuda.device_count()
         | 
| 109 | 
            +
                if ngpus_per_node > 1:
         | 
| 110 | 
            +
                    print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
         | 
| 111 | 
            +
                    return -1
         | 
| 112 | 
            +
                
         | 
| 113 | 
            +
                main_worker(args)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            if __name__ == '__main__':
         | 
| 117 | 
            +
                main()
         | 
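A minimal end-to-end sketch (not part of this commit) of the same inference path driven directly from Python rather than through argparse. The checkpoint and image paths are assumptions, and the import assumes the script is run from inside iebins/ as above:

    import numpy as np
    import torch
    from PIL import Image
    from torchvision import transforms
    from networks.NewCRFDepth import NewCRFDepth

    # Build the large Swin variant and load an NYUv2 checkpoint (path assumed).
    model = torch.nn.DataParallel(
        NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained=None)).cuda()
    model.load_state_dict(torch.load('../checkpoints/nyu_L.pth', map_location='cpu')['model'])
    model.eval()

    # Same preprocessing as inference(): scale to [0, 1], then ImageNet-normalize.
    img = np.asarray(Image.open('example.jpg').convert('RGB'), dtype=np.float32) / 255.0
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

    with torch.no_grad():
        pred_depths_r_list, _, _ = model(img.unsqueeze(0).cuda())
    depth = pred_depths_r_list[-1].squeeze().cpu().numpy()  # metric depth map, H x W
    print(depth.shape, float(depth.min()), float(depth.max()))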
    	
        iebins/networks/NewCRFDepth.py
    ADDED
    
    | @@ -0,0 +1,318 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from .swin_transformer import SwinTransformer
         | 
| 6 | 
            +
            from .newcrf_layers import NewCRF
         | 
| 7 | 
            +
            from .uper_crf_head import PSP
         | 
| 8 | 
            +
from .depth_update import *
         | 
| 9 | 
            +
            ########################################################################################################################
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class NewCRFDepth(nn.Module):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                Depth network based on neural window FC-CRFs architecture.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
                def __init__(self, version=None, inv_depth=False, pretrained=None, 
         | 
| 17 | 
            +
                                frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs):
         | 
| 18 | 
            +
                    super().__init__()
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    self.inv_depth = inv_depth
         | 
| 21 | 
            +
                    self.with_auxiliary_head = False
         | 
| 22 | 
            +
                    self.with_neck = False
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    norm_cfg = dict(type='BN', requires_grad=True)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    window_size = int(version[-2:])
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    if version[:-2] == 'base':
         | 
| 29 | 
            +
                        embed_dim = 128
         | 
| 30 | 
            +
                        depths = [2, 2, 18, 2]
         | 
| 31 | 
            +
                        num_heads = [4, 8, 16, 32]
         | 
| 32 | 
            +
                        in_channels = [128, 256, 512, 1024]
         | 
| 33 | 
            +
                        self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=128)
         | 
| 34 | 
            +
                    elif version[:-2] == 'large':
         | 
| 35 | 
            +
                        embed_dim = 192
         | 
| 36 | 
            +
                        depths = [2, 2, 18, 2]
         | 
| 37 | 
            +
                        num_heads = [6, 12, 24, 48]
         | 
| 38 | 
            +
                        in_channels = [192, 384, 768, 1536]
         | 
| 39 | 
            +
                        self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=192)
         | 
| 40 | 
            +
                    elif version[:-2] == 'tiny':
         | 
| 41 | 
            +
                        embed_dim = 96
         | 
| 42 | 
            +
                        depths = [2, 2, 6, 2]
         | 
| 43 | 
            +
                        num_heads = [3, 6, 12, 24]
         | 
| 44 | 
            +
                        in_channels = [96, 192, 384, 768]
         | 
| 45 | 
            +
                        self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=96)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    backbone_cfg = dict(
         | 
| 48 | 
            +
                        embed_dim=embed_dim,
         | 
| 49 | 
            +
                        depths=depths,
         | 
| 50 | 
            +
                        num_heads=num_heads,
         | 
| 51 | 
            +
                        window_size=window_size,
         | 
| 52 | 
            +
                        ape=False,
         | 
| 53 | 
            +
                        drop_path_rate=0.3,
         | 
| 54 | 
            +
                        patch_norm=True,
         | 
| 55 | 
            +
                        use_checkpoint=False,
         | 
| 56 | 
            +
                        frozen_stages=frozen_stages
         | 
| 57 | 
            +
                    )
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    embed_dim = 512
         | 
| 60 | 
            +
                    decoder_cfg = dict(
         | 
| 61 | 
            +
                        in_channels=in_channels,
         | 
| 62 | 
            +
                        in_index=[0, 1, 2, 3],
         | 
| 63 | 
            +
                        pool_scales=(1, 2, 3, 6),
         | 
| 64 | 
            +
                        channels=embed_dim,
         | 
| 65 | 
            +
                        dropout_ratio=0.0,
         | 
| 66 | 
            +
                        num_classes=32,
         | 
| 67 | 
            +
                        norm_cfg=norm_cfg,
         | 
| 68 | 
            +
                        align_corners=False
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.backbone = SwinTransformer(**backbone_cfg)
         | 
| 72 | 
            +
                    v_dim = decoder_cfg['num_classes']*4
         | 
| 73 | 
            +
                    win = 7
         | 
| 74 | 
            +
                    crf_dims = [128, 256, 512, 1024]
         | 
| 75 | 
            +
                    v_dims = [64, 128, 256, embed_dim]
         | 
| 76 | 
            +
                    self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
         | 
| 77 | 
            +
                    self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
         | 
| 78 | 
            +
                    self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    self.decoder = PSP(**decoder_cfg)
         | 
| 81 | 
            +
                    self.disp_head1 = DispHead(input_dim=crf_dims[0])
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    self.up_mode = 'bilinear'
         | 
| 84 | 
            +
                    if self.up_mode == 'mask':
         | 
| 85 | 
            +
                        self.mask_head = nn.Sequential(
         | 
| 86 | 
            +
                            nn.Conv2d(v_dims[0], 64, 3, padding=1),
         | 
| 87 | 
            +
                            nn.ReLU(inplace=True),
         | 
| 88 | 
            +
                            nn.Conv2d(64, 16*9, 1, padding=0))
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    self.min_depth = min_depth
         | 
| 91 | 
            +
                    self.max_depth = max_depth
         | 
| 92 | 
            +
                    self.depth_num = 16
         | 
| 93 | 
            +
                    self.hidden_dim = 128
         | 
| 94 | 
            +
                    self.project = Projection(v_dims[0], self.hidden_dim)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.init_weights(pretrained=pretrained)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def init_weights(self, pretrained=None):
         | 
| 99 | 
            +
                    """Initialize the weights in backbone and heads.
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    Args:
         | 
| 102 | 
            +
                        pretrained (str, optional): Path to pre-trained weights.
         | 
| 103 | 
            +
                            Defaults to None.
         | 
| 104 | 
            +
                    """
         | 
| 105 | 
            +
                    print(f'== Load encoder backbone from: {pretrained}')
         | 
| 106 | 
            +
                    self.backbone.init_weights(pretrained=pretrained)
         | 
| 107 | 
            +
                    self.decoder.init_weights()
         | 
| 108 | 
            +
                    if self.with_auxiliary_head:
         | 
| 109 | 
            +
                        if isinstance(self.auxiliary_head, nn.ModuleList):
         | 
| 110 | 
            +
                            for aux_head in self.auxiliary_head:
         | 
| 111 | 
            +
                                aux_head.init_weights()
         | 
| 112 | 
            +
                        else:
         | 
| 113 | 
            +
                            self.auxiliary_head.init_weights()
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def upsample_mask(self, disp, mask):
         | 
| 116 | 
            +
                    """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """
         | 
| 117 | 
            +
                    N, C, H, W = disp.shape
         | 
| 118 | 
            +
                    mask = mask.view(N, 1, 9, 4, 4, H, W)
         | 
| 119 | 
            +
                    mask = torch.softmax(mask, dim=2)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    up_disp = F.unfold(disp, kernel_size=3, padding=1)
         | 
| 122 | 
            +
                    up_disp = up_disp.view(N, C, 9, 1, 1, H, W)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    up_disp = torch.sum(mask * up_disp, dim=2)
         | 
| 125 | 
            +
                    up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
         | 
| 126 | 
            +
                    return up_disp.reshape(N, C, 4*H, 4*W)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def forward(self, imgs, epoch=1, step=100):
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    feats = self.backbone(imgs)
         | 
| 131 | 
            +
                    ppm_out = self.decoder(feats)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    e3 = self.crf3(feats[3], ppm_out)
         | 
| 134 | 
            +
                    e3 = nn.PixelShuffle(2)(e3)
         | 
| 135 | 
            +
                    e2 = self.crf2(feats[2], e3)
         | 
| 136 | 
            +
                    e2 = nn.PixelShuffle(2)(e2)
         | 
| 137 | 
            +
                    e1 = self.crf1(feats[1], e2)
         | 
| 138 | 
            +
                    e1 = nn.PixelShuffle(2)(e1)
         | 
| 139 | 
            +
                    
         | 
| 140 | 
            +
                    # iterative bins
         | 
| 141 | 
            +
                    if epoch == 0 and step < 80:
         | 
| 142 | 
            +
                        max_tree_depth = 3
         | 
| 143 | 
            +
                    else:
         | 
| 144 | 
            +
                        max_tree_depth = 6
         | 
| 145 | 
            +
                    
         | 
| 146 | 
            +
                    if self.up_mode == 'mask':
         | 
| 147 | 
            +
                        mask = self.mask_head(e1)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    b, c, h, w = e1.shape
         | 
| 150 | 
            +
                    device = e1.device
         | 
| 151 | 
            +
                           
         | 
| 152 | 
            +
                    depth = torch.zeros([b, 1, h, w]).to(device)
         | 
| 153 | 
            +
                    context = feats[0]
         | 
| 154 | 
            +
                    gru_hidden = torch.tanh(self.project(e1))
         | 
| 155 | 
            +
                    pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = self.update(depth, context, gru_hidden, max_tree_depth, self.depth_num, self.min_depth, self.max_depth)
         | 
| 156 | 
            +
                    
         | 
| 157 | 
            +
                    if self.up_mode == 'mask':
         | 
| 158 | 
            +
                        for i in range(len(pred_depths_r_list)):
         | 
| 159 | 
            +
                            pred_depths_r_list[i] = self.upsample_mask(pred_depths_r_list[i], mask)  
         | 
| 160 | 
            +
                        for i in range(len(pred_depths_c_list)):
         | 
| 161 | 
            +
                            pred_depths_c_list[i] = self.upsample_mask(pred_depths_c_list[i], mask.detach())
         | 
| 162 | 
            +
                        for i in range(len(uncertainty_maps_list)):
         | 
| 163 | 
            +
                            uncertainty_maps_list[i] = self.upsample_mask(uncertainty_maps_list[i], mask.detach())                   
         | 
| 164 | 
            +
                    else:
         | 
| 165 | 
            +
                        for i in range(len(pred_depths_r_list)):
         | 
| 166 | 
            +
                            pred_depths_r_list[i] = upsample(pred_depths_r_list[i], scale_factor=4)
         | 
| 167 | 
            +
                        for i in range(len(pred_depths_c_list)):
         | 
| 168 | 
            +
                            pred_depths_c_list[i] = upsample(pred_depths_c_list[i], scale_factor=4) 
         | 
| 169 | 
            +
                        for i in range(len(uncertainty_maps_list)):
         | 
| 170 | 
            +
                            uncertainty_maps_list[i] = upsample(uncertainty_maps_list[i], scale_factor=4) 
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list
         | 
| 173 | 
            +
             | 
| 174 | 
            +
            class DispHead(nn.Module):
         | 
| 175 | 
            +
                def __init__(self, input_dim=100):
         | 
| 176 | 
            +
                    super(DispHead, self).__init__()
         | 
| 177 | 
            +
                    # self.norm1 = nn.BatchNorm2d(input_dim)
         | 
| 178 | 
            +
                    self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
         | 
| 179 | 
            +
                    # self.relu = nn.ReLU(inplace=True)
         | 
| 180 | 
            +
                    self.sigmoid = nn.Sigmoid()
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def forward(self, x, scale):
         | 
| 183 | 
            +
                    # x = self.relu(self.norm1(x))
         | 
| 184 | 
            +
                    x = self.sigmoid(self.conv1(x))
         | 
| 185 | 
            +
                    if scale > 1:
         | 
| 186 | 
            +
                        x = upsample(x, scale_factor=scale)
         | 
| 187 | 
            +
                    return x
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            class BasicUpdateBlockDepth(nn.Module):
         | 
| 190 | 
            +
                def __init__(self, hidden_dim=128, context_dim=192):
         | 
| 191 | 
            +
                    super(BasicUpdateBlockDepth, self).__init__()
         | 
| 192 | 
            +
                            
         | 
| 193 | 
            +
                    self.encoder = ProjectionInputDepth(hidden_dim=hidden_dim, out_chs=hidden_dim * 2)
         | 
| 194 | 
            +
                    self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=self.encoder.out_chs+context_dim)
         | 
| 195 | 
            +
                    self.p_head = PHead(hidden_dim, hidden_dim)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def forward(self, depth, context, gru_hidden, seq_len, depth_num, min_depth, max_depth):
         | 
| 198 | 
            +
             
         | 
| 199 | 
            +
                    pred_depths_r_list = []
         | 
| 200 | 
            +
                    pred_depths_c_list = []
         | 
| 201 | 
            +
                    uncertainty_maps_list = []
         | 
| 202 | 
            +
                  
         | 
| 203 | 
            +
                    b, _, h, w = depth.shape
         | 
| 204 | 
            +
                    depth_range = max_depth - min_depth
         | 
| 205 | 
            +
                    interval = depth_range / depth_num
         | 
| 206 | 
            +
                    interval = interval * torch.ones_like(depth)
         | 
| 207 | 
            +
                    interval = interval.repeat(1, depth_num, 1, 1)
         | 
| 208 | 
            +
                    interval = torch.cat([torch.ones_like(depth) * min_depth, interval], 1)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    bin_edges = torch.cumsum(interval, 1)
         | 
| 211 | 
            +
                    current_depths = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
         | 
| 212 | 
            +
                    index_iter = 0
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    for i in range(seq_len):
         | 
| 215 | 
            +
                        input_features = self.encoder(current_depths.detach())
         | 
| 216 | 
            +
                        input_c = torch.cat([input_features, context], dim=1)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        gru_hidden = self.gru(gru_hidden, input_c)
         | 
| 219 | 
            +
                        pred_prob = self.p_head(gru_hidden)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                        depth_r = (pred_prob * current_depths.detach()).sum(1, keepdim=True)
         | 
| 222 | 
            +
                        pred_depths_r_list.append(depth_r)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                        uncertainty_map = torch.sqrt((pred_prob * ((current_depths.detach() - depth_r.repeat(1, depth_num, 1, 1))**2)).sum(1, keepdim=True))
         | 
| 225 | 
            +
                        uncertainty_maps_list.append(uncertainty_map)
         | 
| 226 | 
            +
                    
         | 
| 227 | 
            +
                        index_iter = index_iter + 1
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                        pred_label = get_label(torch.squeeze(depth_r, 1), bin_edges, depth_num).unsqueeze(1)
         | 
| 230 | 
            +
                        depth_c = torch.gather(current_depths.detach(), 1, pred_label.detach())
         | 
| 231 | 
            +
                        pred_depths_c_list.append(depth_c)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                        label_target_bin_left = pred_label
         | 
| 234 | 
            +
                        target_bin_left = torch.gather(bin_edges, 1, label_target_bin_left)
         | 
| 235 | 
            +
                        label_target_bin_right = (pred_label.float() + 1).long()
         | 
| 236 | 
            +
                        target_bin_right = torch.gather(bin_edges, 1, label_target_bin_right)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                        bin_edges, current_depths = update_sample(bin_edges, target_bin_left, target_bin_right, depth_r.detach(), pred_label.detach(), depth_num, min_depth, max_depth, uncertainty_map)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list
         | 
| 241 | 
            +
             | 
| 242 | 
            +
            class PHead(nn.Module):
         | 
| 243 | 
            +
                def __init__(self, input_dim=128, hidden_dim=128):
         | 
| 244 | 
            +
                    super(PHead, self).__init__()
         | 
| 245 | 
            +
                    self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
         | 
| 246 | 
            +
                    self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)
         | 
| 247 | 
            +
                
         | 
| 248 | 
            +
                def forward(self, x):
         | 
| 249 | 
            +
                    out = torch.softmax(self.conv2(F.relu(self.conv1(x))), 1)
         | 
| 250 | 
            +
                    return out
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            class SepConvGRU(nn.Module):
         | 
| 253 | 
            +
                def __init__(self, hidden_dim=128, input_dim=128+192):
         | 
| 254 | 
            +
                    super(SepConvGRU, self).__init__()
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
         | 
| 257 | 
            +
                    self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
         | 
| 258 | 
            +
                    self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
         | 
| 259 | 
            +
                    self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
         | 
| 260 | 
            +
                    self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
         | 
| 261 | 
            +
                    self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def forward(self, h, x):
         | 
| 264 | 
            +
                    # horizontal
         | 
| 265 | 
            +
                    hx = torch.cat([h, x], dim=1)
         | 
| 266 | 
            +
                    z = torch.sigmoid(self.convz1(hx))
         | 
| 267 | 
            +
                    r = torch.sigmoid(self.convr1(hx))
         | 
| 268 | 
            +
                    q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 
         | 
| 269 | 
            +
                    
         | 
| 270 | 
            +
                    h = (1-z) * h + z * q
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    # vertical
         | 
| 273 | 
            +
                    hx = torch.cat([h, x], dim=1)
         | 
| 274 | 
            +
                    z = torch.sigmoid(self.convz2(hx))
         | 
| 275 | 
            +
                    r = torch.sigmoid(self.convr2(hx))
         | 
| 276 | 
            +
                    q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       
         | 
| 277 | 
            +
                    h = (1-z) * h + z * q
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    return h
         | 
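Shape sketch (illustrative only) for the separable GRU above: gating runs first along rows with 1x5 convolutions, then along columns with 5x1 convolutions, keeping the hidden state at the same 1/4 resolution. The sizes below assume the large07 configuration, a 480x640 input, and that the repository root is on PYTHONPATH:

    import torch
    from iebins.networks.NewCRFDepth import SepConvGRU

    gru = SepConvGRU(hidden_dim=128, input_dim=256 + 192)  # depth encoding (256) + context (192)
    h = torch.randn(1, 128, 120, 160)                      # GRU hidden state
    x = torch.randn(1, 256 + 192, 120, 160)                # concatenated input features
    print(gru(h, x).shape)                                 # torch.Size([1, 128, 120, 160])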
| 280 | 
            +
             | 
| 281 | 
            +
            class ProjectionInputDepth(nn.Module):
         | 
| 282 | 
            +
                def __init__(self, hidden_dim, out_chs):
         | 
| 283 | 
            +
                    super().__init__()
         | 
| 284 | 
            +
                    self.out_chs = out_chs 
         | 
| 285 | 
            +
                    self.convd1 = nn.Conv2d(16, hidden_dim, 7, padding=3)
         | 
| 286 | 
            +
                    self.convd2 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
         | 
| 287 | 
            +
                    self.convd3 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
         | 
| 288 | 
            +
                    self.convd4 = nn.Conv2d(hidden_dim, out_chs, 3, padding=1)
         | 
| 289 | 
            +
                    
         | 
| 290 | 
            +
                def forward(self, depth):
         | 
| 291 | 
            +
                    d = F.relu(self.convd1(depth))
         | 
| 292 | 
            +
                    d = F.relu(self.convd2(d))
         | 
| 293 | 
            +
                    d = F.relu(self.convd3(d))
         | 
| 294 | 
            +
                    d = F.relu(self.convd4(d))
         | 
| 295 | 
            +
                            
         | 
| 296 | 
            +
                    return d
         | 
| 297 | 
            +
             | 
| 298 | 
            +
            class Projection(nn.Module):
         | 
| 299 | 
            +
                def __init__(self, in_chs, out_chs):
         | 
| 300 | 
            +
                    super().__init__()
         | 
| 301 | 
            +
                    self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1)
         | 
| 302 | 
            +
                    
         | 
| 303 | 
            +
                def forward(self, x):
         | 
| 304 | 
            +
                    out = self.conv(x)
         | 
| 305 | 
            +
                            
         | 
| 306 | 
            +
                    return out
         | 
| 307 | 
            +
             | 
| 308 | 
            +
            def upsample(x, scale_factor=2, mode="bilinear", align_corners=False):
         | 
| 309 | 
            +
                """Upsample input tensor by a factor of 2
         | 
| 310 | 
            +
                """
         | 
| 311 | 
            +
                return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
            def upsample1(x, scale_factor=2, mode="bilinear"):
         | 
| 314 | 
            +
                """Upsample input tensor by a factor of 2
         | 
| 315 | 
            +
                """
         | 
| 316 | 
            +
                return F.interpolate(x, scale_factor=scale_factor, mode=mode)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
             | 
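As a shape sketch (illustrative, not repository code): with the default epoch/step arguments the forward pass above runs six refinement iterations and returns three lists, each entry a B x 1 x H x W map at full input resolution. The NYUv2-sized input and the import path (repository root on PYTHONPATH) are assumptions:

    import torch
    from iebins.networks.NewCRFDepth import NewCRFDepth

    model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained=None).eval()
    x = torch.randn(1, 3, 480, 640)                        # NYUv2-sized RGB batch
    with torch.no_grad():
        depths_r, depths_c, uncertainties = model(x)       # default epoch=1 -> max_tree_depth=6
    print(len(depths_r), depths_r[-1].shape)               # 6 torch.Size([1, 1, 480, 640])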
    	
        iebins/networks/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        iebins/networks/depth_update.py
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
            import copy
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label, depth_num, min_depth, max_depth, uncertainty_range):
         | 
| 6 | 
            +
                
         | 
| 7 | 
            +
                with torch.no_grad():    
         | 
| 8 | 
            +
                    b, _, h, w = bin_edges.shape
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                    mode = 'direct'
         | 
| 11 | 
            +
                    if mode == 'direct':
         | 
| 12 | 
            +
                        depth_range = uncertainty_range
         | 
| 13 | 
            +
                        depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
         | 
| 14 | 
            +
                    else:
         | 
| 15 | 
            +
                        depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
         | 
| 16 | 
            +
                        depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                    interval = depth_range / depth_num
         | 
| 19 | 
            +
                    interval = interval.repeat(1, depth_num, 1, 1)
         | 
| 20 | 
            +
                    interval = torch.cat([torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update, interval], 1)
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    bin_edges = torch.cumsum(interval, 1).clamp(min_depth, max_depth)
         | 
| 23 | 
            +
                    curr_depth = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
         | 
| 24 | 
            +
                    
         | 
| 25 | 
            +
                return bin_edges.detach(), curr_depth.detach()
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def get_label(gt_depth_img, bin_edges, depth_num):
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                with torch.no_grad():
         | 
| 30 | 
            +
                    gt_label = torch.zeros(gt_depth_img.size(), dtype=torch.int64, device=gt_depth_img.device)
         | 
| 31 | 
            +
                    for i in range(depth_num):
         | 
| 32 | 
            +
                        bin_mask = torch.ge(gt_depth_img, bin_edges[:, i])
         | 
| 33 | 
            +
                        bin_mask = torch.logical_and(bin_mask, 
         | 
| 34 | 
            +
                            torch.lt(gt_depth_img, bin_edges[:, i + 1]))
         | 
| 35 | 
            +
                        gt_label[bin_mask] = i
         | 
| 36 | 
            +
                    
         | 
| 37 | 
            +
                    return gt_label
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
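A tiny numeric sketch (illustration only) of the uniform initial bins that BasicUpdateBlockDepth builds and the index get_label assigns to a ground-truth depth; the import path assumes the repository root is on PYTHONPATH:

    import torch
    from iebins.networks.depth_update import get_label

    min_depth, max_depth, depth_num = 0.1, 10.0, 16
    b, h, w = 1, 1, 1
    interval = (max_depth - min_depth) / depth_num            # 0.61875 m per bin
    edges = torch.cat([torch.full((b, 1, h, w), min_depth),
                       torch.full((b, depth_num, h, w), interval)], dim=1)
    bin_edges = torch.cumsum(edges, dim=1)                    # 17 edges from 0.1 m to 10.0 m
    gt = torch.full((b, h, w), 3.0)                           # a 3 m ground-truth depth
    print(int(get_label(gt, bin_edges, depth_num)))           # -> 4 (the bin containing 3 m)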
    	
        iebins/networks/newcrf_layers.py
    ADDED
    
    | @@ -0,0 +1,433 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            import torch.utils.checkpoint as checkpoint
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from timm.models.layers import DropPath, to_2tuple, trunc_normal_
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class Mlp(nn.Module):
         | 
| 10 | 
            +
                """ Multilayer perceptron."""
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         | 
| 13 | 
            +
                    super().__init__()
         | 
| 14 | 
            +
                    out_features = out_features or in_features
         | 
| 15 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 16 | 
            +
                    self.fc1 = nn.Linear(in_features, hidden_features)
         | 
| 17 | 
            +
                    self.act = act_layer()
         | 
| 18 | 
            +
                    self.fc2 = nn.Linear(hidden_features, out_features)
         | 
| 19 | 
            +
                    self.drop = nn.Dropout(drop)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def forward(self, x):
         | 
| 22 | 
            +
                    x = self.fc1(x)
         | 
| 23 | 
            +
                    x = self.act(x)
         | 
| 24 | 
            +
                    x = self.drop(x)
         | 
| 25 | 
            +
                    x = self.fc2(x)
         | 
| 26 | 
            +
                    x = self.drop(x)
         | 
| 27 | 
            +
                    return x
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def window_partition(x, window_size):
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                Args:
         | 
| 33 | 
            +
                    x: (B, H, W, C)
         | 
| 34 | 
            +
                    window_size (int): window size
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                Returns:
         | 
| 37 | 
            +
                    windows: (num_windows*B, window_size, window_size, C)
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                B, H, W, C = x.shape
         | 
| 40 | 
            +
                x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
         | 
| 41 | 
            +
                windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
         | 
| 42 | 
            +
                return windows
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def window_reverse(windows, window_size, H, W):
         | 
| 46 | 
            +
                """
         | 
| 47 | 
            +
                Args:
         | 
| 48 | 
            +
                    windows: (num_windows*B, window_size, window_size, C)
         | 
| 49 | 
            +
                    window_size (int): Window size
         | 
| 50 | 
            +
                    H (int): Height of image
         | 
| 51 | 
            +
                    W (int): Width of image
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                Returns:
         | 
| 54 | 
            +
                    x: (B, H, W, C)
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                B = int(windows.shape[0] / (H * W / window_size / window_size))
         | 
| 57 | 
            +
                x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
         | 
| 58 | 
            +
                x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
         | 
| 59 | 
            +
                return x
         | 
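A quick round-trip check (illustrative only) of the two helpers above: when H and W are multiples of the window size, window_reverse exactly undoes window_partition. The import path assumes the repository root is on PYTHONPATH:

    import torch
    from iebins.networks.newcrf_layers import window_partition, window_reverse

    x = torch.randn(2, 14, 14, 96)             # B, H, W, C with window_size = 7
    windows = window_partition(x, 7)           # 2 images, 2 x 2 windows each
    y = window_reverse(windows, 7, 14, 14)
    print(windows.shape, torch.equal(x, y))    # torch.Size([8, 7, 7, 96]) True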
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            class WindowAttention(nn.Module):
         | 
| 63 | 
            +
                """ Window based multi-head self attention (W-MSA) module with relative position bias.
         | 
| 64 | 
            +
                It supports both of shifted and non-shifted window.
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                Args:
         | 
| 67 | 
            +
                    dim (int): Number of input channels.
         | 
| 68 | 
            +
                    window_size (tuple[int]): The height and width of the window.
         | 
| 69 | 
            +
        num_heads (int): Number of attention heads.
        v_dim (int): Number of channels of the externally supplied value features.
         | 
| 70 | 
            +
                    qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
         | 
| 71 | 
            +
                    qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
         | 
| 72 | 
            +
                    attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
         | 
| 73 | 
            +
                    proj_drop (float, optional): Dropout ratio of output. Default: 0.0
         | 
| 74 | 
            +
                """
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    super().__init__()
         | 
| 79 | 
            +
                    self.dim = dim
         | 
| 80 | 
            +
                    self.window_size = window_size  # Wh, Ww
         | 
| 81 | 
            +
                    self.num_heads = num_heads
         | 
| 82 | 
            +
                    head_dim = dim // num_heads
         | 
| 83 | 
            +
                    self.scale = qk_scale or head_dim ** -0.5
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    # define a parameter table of relative position bias
         | 
| 86 | 
            +
                    self.relative_position_bias_table = nn.Parameter(
         | 
| 87 | 
            +
                        torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # get pair-wise relative position index for each token inside the window
         | 
| 90 | 
            +
                    coords_h = torch.arange(self.window_size[0])
         | 
| 91 | 
            +
                    coords_w = torch.arange(self.window_size[1])
         | 
| 92 | 
            +
                    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
         | 
| 93 | 
            +
                    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         | 
| 94 | 
            +
                    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
         | 
| 95 | 
            +
                    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
         | 
| 96 | 
            +
                    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
         | 
| 97 | 
            +
                    relative_coords[:, :, 1] += self.window_size[1] - 1
         | 
| 98 | 
            +
                    relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         | 
| 99 | 
            +
                    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         | 
| 100 | 
            +
                    self.register_buffer("relative_position_index", relative_position_index)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
         | 
| 103 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 104 | 
            +
                    self.proj = nn.Linear(v_dim, v_dim)
         | 
| 105 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    trunc_normal_(self.relative_position_bias_table, std=.02)
         | 
| 108 | 
            +
                    self.softmax = nn.Softmax(dim=-1)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def forward(self, x, v, mask=None):
         | 
| 111 | 
            +
                    """ Forward function.
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    Args:
         | 
| 114 | 
            +
                        x: input features with shape of (num_windows*B, N, C)
         | 
| 115 | 
            +
                        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         | 
| 116 | 
            +
                    """
         | 
| 117 | 
            +
                    B_, N, C = x.shape
         | 
| 118 | 
            +
                    qk = self.qk(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         | 
| 119 | 
            +
                    q, k = qk[0], qk[1]  # make torchscript happy (cannot use tensor as tuple)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    q = q * self.scale
         | 
| 122 | 
            +
                    attn = (q @ k.transpose(-2, -1))
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
         | 
| 125 | 
            +
                        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
         | 
| 126 | 
            +
                    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         | 
| 127 | 
            +
                    attn = attn + relative_position_bias.unsqueeze(0)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    if mask is not None:
         | 
| 130 | 
            +
                        nW = mask.shape[0]
         | 
| 131 | 
            +
                        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
         | 
| 132 | 
            +
                        attn = attn.view(-1, self.num_heads, N, N)
         | 
| 133 | 
            +
                        attn = self.softmax(attn)
         | 
| 134 | 
            +
                    else:
         | 
| 135 | 
            +
                        attn = self.softmax(attn)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 138 | 
            +
                    
         | 
| 139 | 
            +
                    # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0"
         | 
| 140 | 
            +
                    # repeat_num = self.dim // v.shape[-1]
         | 
| 141 | 
            +
                    # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    assert self.dim == v.shape[-1], "self.dim != v.shape[-1]"
         | 
| 144 | 
            +
                    v = v.view(B_, N, self.num_heads, -1).transpose(1, 2)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
         | 
| 147 | 
            +
                    x = self.proj(x)
         | 
| 148 | 
            +
                    x = self.proj_drop(x)
         | 
| 149 | 
            +
                    return x
         | 
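Shape sketch (illustrative only): unlike plain W-MSA, queries and keys are computed from x while the value tensor v is supplied separately (here from the decoder path), and the assert above requires v_dim to equal dim:

    import torch
    from iebins.networks.newcrf_layers import WindowAttention

    attn = WindowAttention(dim=128, window_size=(7, 7), num_heads=4, v_dim=128)
    x = torch.randn(8, 49, 128)                # num_windows*B, Wh*Ww, C
    v = torch.randn(8, 49, 128)                # value features with the same window layout
    print(attn(x, v).shape)                    # torch.Size([8, 49, 128])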
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
            class CRFBlock(nn.Module):
         | 
| 153 | 
            +
                """ CRF Block.
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                Args:
         | 
| 156 | 
            +
                    dim (int): Number of input channels.
         | 
| 157 | 
            +
        num_heads (int): Number of attention heads.
        v_dim (int): Number of channels of the value features passed to WindowAttention.
         | 
| 158 | 
            +
                    window_size (int): Window size.
         | 
| 159 | 
            +
                    shift_size (int): Shift size for SW-MSA.
         | 
| 160 | 
            +
                    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
         | 
| 161 | 
            +
                    qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
         | 
| 162 | 
            +
                    qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
         | 
| 163 | 
            +
                    drop (float, optional): Dropout rate. Default: 0.0
         | 
| 164 | 
            +
                    attn_drop (float, optional): Attention dropout rate. Default: 0.0
         | 
| 165 | 
            +
                    drop_path (float, optional): Stochastic depth rate. Default: 0.0
         | 
| 166 | 
            +
                    act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
         | 
| 167 | 
            +
                    norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
         | 
| 168 | 
            +
                """
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                def __init__(self, dim, num_heads, v_dim, window_size=7, shift_size=0,
         | 
| 171 | 
            +
                             mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
         | 
| 172 | 
            +
                             act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         | 
| 173 | 
            +
                    super().__init__()
         | 
| 174 | 
            +
                    self.dim = dim
         | 
| 175 | 
            +
                    self.num_heads = num_heads
         | 
| 176 | 
            +
                    self.v_dim = v_dim
         | 
| 177 | 
            +
                    self.window_size = window_size
         | 
| 178 | 
            +
                    self.shift_size = shift_size
         | 
| 179 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 180 | 
            +
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    self.norm1 = norm_layer(dim)
         | 
| 183 | 
            +
                    self.attn = WindowAttention(
         | 
| 184 | 
            +
                        dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim,
         | 
| 185 | 
            +
                        qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         | 
| 188 | 
            +
                    self.norm2 = norm_layer(v_dim)
         | 
| 189 | 
            +
                    mlp_hidden_dim = int(v_dim * mlp_ratio)
         | 
| 190 | 
            +
                    self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    self.H = None
         | 
| 193 | 
            +
                    self.W = None
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                def forward(self, x, v, mask_matrix):
         | 
| 196 | 
            +
                    """ Forward function.
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    Args:
         | 
| 199 | 
            +
            x: Input feature, tensor size (B, H*W, C).
            v: Value feature, tensor size (B, H, W, C).
         | 
| 200 | 
            +
                        H, W: Spatial resolution of the input feature.
         | 
| 201 | 
            +
                        mask_matrix: Attention mask for cyclic shift.
         | 
| 202 | 
            +
                    """
         | 
| 203 | 
            +
                    B, L, C = x.shape
         | 
| 204 | 
            +
                    H, W = self.H, self.W
         | 
| 205 | 
            +
                    assert L == H * W, "input feature has wrong size"
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    shortcut = x
         | 
| 208 | 
            +
                    x = self.norm1(x)
         | 
| 209 | 
            +
                    x = x.view(B, H, W, C)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    # pad feature maps to multiples of window size
         | 
| 212 | 
            +
                    pad_l = pad_t = 0
         | 
| 213 | 
            +
                    pad_r = (self.window_size - W % self.window_size) % self.window_size
         | 
| 214 | 
            +
                    pad_b = (self.window_size - H % self.window_size) % self.window_size
         | 
| 215 | 
            +
                    x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
         | 
| 216 | 
            +
                    v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
         | 
| 217 | 
            +
                    _, Hp, Wp, _ = x.shape
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    # cyclic shift
         | 
| 220 | 
            +
                    if self.shift_size > 0:
         | 
| 221 | 
            +
                        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
         | 
| 222 | 
            +
                        shifted_v = torch.roll(v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
         | 
| 223 | 
            +
                        attn_mask = mask_matrix
         | 
| 224 | 
            +
                    else:
         | 
| 225 | 
            +
                        shifted_x = x
         | 
| 226 | 
            +
                        shifted_v = v
         | 
| 227 | 
            +
                        attn_mask = None
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # partition windows
         | 
| 230 | 
            +
                    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
         | 
| 231 | 
            +
                    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
         | 
| 232 | 
            +
                    v_windows = window_partition(shifted_v, self.window_size)  # nW*B, window_size, window_size, C
         | 
| 233 | 
            +
                    v_windows = v_windows.view(-1, self.window_size * self.window_size, v_windows.shape[-1])  # nW*B, window_size*window_size, C
         | 
| 234 | 
            +
                    
         | 
| 235 | 
            +
                    # W-MSA/SW-MSA
         | 
| 236 | 
            +
                    attn_windows = self.attn(x_windows, v_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    # merge windows
         | 
| 239 | 
            +
                    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim)
         | 
| 240 | 
            +
                    shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    # reverse cyclic shift
         | 
| 243 | 
            +
                    if self.shift_size > 0:
         | 
| 244 | 
            +
                        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
         | 
| 245 | 
            +
                    else:
         | 
| 246 | 
            +
                        x = shifted_x
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    if pad_r > 0 or pad_b > 0:
         | 
| 249 | 
            +
                        x = x[:, :H, :W, :].contiguous()
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    x = x.view(B, H * W, self.v_dim)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    # FFN
         | 
| 254 | 
            +
                    x = shortcut + self.drop_path(x)
         | 
| 255 | 
            +
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    return x
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
            class BasicCRFLayer(nn.Module):
         | 
| 261 | 
            +
                """ A basic NeWCRFs layer for one stage.
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                Args:
         | 
| 264 | 
            +
                    dim (int): Number of feature channels
         | 
| 265 | 
            +
                    depth (int): Depths of this stage.
         | 
| 266 | 
            +
                    num_heads (int): Number of attention head.
         | 
| 267 | 
            +
                    window_size (int): Local window size. Default: 7.
         | 
| 268 | 
            +
                    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
         | 
| 269 | 
            +
                    qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
         | 
| 270 | 
            +
                    qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
         | 
| 271 | 
            +
                    drop (float, optional): Dropout rate. Default: 0.0
         | 
| 272 | 
            +
                    attn_drop (float, optional): Attention dropout rate. Default: 0.0
         | 
| 273 | 
            +
                    drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
         | 
| 274 | 
            +
                    norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
         | 
| 275 | 
            +
                    downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
         | 
| 276 | 
            +
                    use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
         | 
| 277 | 
            +
                """
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                def __init__(self,
         | 
| 280 | 
            +
                             dim,
         | 
| 281 | 
            +
                             depth,
         | 
| 282 | 
            +
                             num_heads,
         | 
| 283 | 
            +
                             v_dim,
         | 
| 284 | 
            +
                             window_size=7,
         | 
| 285 | 
            +
                             mlp_ratio=4.,
         | 
| 286 | 
            +
                             qkv_bias=True,
         | 
| 287 | 
            +
                             qk_scale=None,
         | 
| 288 | 
            +
                             drop=0.,
         | 
| 289 | 
            +
                             attn_drop=0.,
         | 
| 290 | 
            +
                             drop_path=0.,
         | 
| 291 | 
            +
                             norm_layer=nn.LayerNorm,
         | 
| 292 | 
            +
                             downsample=None,
         | 
| 293 | 
            +
                             use_checkpoint=False):
         | 
| 294 | 
            +
                    super().__init__()
         | 
| 295 | 
            +
                    self.window_size = window_size
         | 
| 296 | 
            +
                    self.shift_size = window_size // 2
         | 
| 297 | 
            +
                    self.depth = depth
         | 
| 298 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    # build blocks
         | 
| 301 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 302 | 
            +
                        CRFBlock(
         | 
| 303 | 
            +
                            dim=dim,
         | 
| 304 | 
            +
                            num_heads=num_heads,
         | 
| 305 | 
            +
                            v_dim=v_dim,
         | 
| 306 | 
            +
                            window_size=window_size,
         | 
| 307 | 
            +
                            shift_size=0 if (i % 2 == 0) else window_size // 2,
         | 
| 308 | 
            +
                            mlp_ratio=mlp_ratio,
         | 
| 309 | 
            +
                            qkv_bias=qkv_bias,
         | 
| 310 | 
            +
                            qk_scale=qk_scale,
         | 
| 311 | 
            +
                            drop=drop,
         | 
| 312 | 
            +
                            attn_drop=attn_drop,
         | 
| 313 | 
            +
                            drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
         | 
| 314 | 
            +
                            norm_layer=norm_layer)
         | 
| 315 | 
            +
                        for i in range(depth)])
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    # patch merging layer
         | 
| 318 | 
            +
                    if downsample is not None:
         | 
| 319 | 
            +
                        self.downsample = downsample(dim=dim, norm_layer=norm_layer)
         | 
| 320 | 
            +
                    else:
         | 
| 321 | 
            +
                        self.downsample = None
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                def forward(self, x, v, H, W):
         | 
| 324 | 
            +
                    """ Forward function.
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    Args:
         | 
| 327 | 
            +
            x: Input feature, tensor size (B, H*W, C).
            v: Value feature, tensor size (B, H, W, C).
         | 
| 328 | 
            +
                        H, W: Spatial resolution of the input feature.
         | 
| 329 | 
            +
                    """
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    # calculate attention mask for SW-MSA
         | 
| 332 | 
            +
                    Hp = int(np.ceil(H / self.window_size)) * self.window_size
         | 
| 333 | 
            +
                    Wp = int(np.ceil(W / self.window_size)) * self.window_size
         | 
| 334 | 
            +
                    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         | 
| 335 | 
            +
                    h_slices = (slice(0, -self.window_size),
         | 
| 336 | 
            +
                                slice(-self.window_size, -self.shift_size),
         | 
| 337 | 
            +
                                slice(-self.shift_size, None))
         | 
| 338 | 
            +
                    w_slices = (slice(0, -self.window_size),
         | 
| 339 | 
            +
                                slice(-self.window_size, -self.shift_size),
         | 
| 340 | 
            +
                                slice(-self.shift_size, None))
         | 
| 341 | 
            +
                    cnt = 0
         | 
| 342 | 
            +
                    for h in h_slices:
         | 
| 343 | 
            +
                        for w in w_slices:
         | 
| 344 | 
            +
                            img_mask[:, h, w, :] = cnt
         | 
| 345 | 
            +
                            cnt += 1
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
         | 
| 348 | 
            +
                    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
         | 
| 349 | 
            +
                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
         | 
| 350 | 
            +
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    for blk in self.blocks:
         | 
| 353 | 
            +
                        blk.H, blk.W = H, W
         | 
| 354 | 
            +
                        if self.use_checkpoint:
         | 
| 355 | 
            +
                x = checkpoint.checkpoint(blk, x, v, attn_mask)  # blk.forward expects (x, v, mask_matrix)
         | 
| 356 | 
            +
                        else:
         | 
| 357 | 
            +
                            x = blk(x, v, attn_mask)
         | 
| 358 | 
            +
                    if self.downsample is not None:
         | 
| 359 | 
            +
                        x_down = self.downsample(x, H, W)
         | 
| 360 | 
            +
                        Wh, Ww = (H + 1) // 2, (W + 1) // 2
         | 
| 361 | 
            +
                        return x, H, W, x_down, Wh, Ww
         | 
| 362 | 
            +
                    else:
         | 
| 363 | 
            +
                        return x, H, W, x, H, W
         | 
| 364 | 
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
            class NewCRF(nn.Module):
         | 
| 367 | 
            +
                def __init__(self,
         | 
| 368 | 
            +
                             input_dim=96,
         | 
| 369 | 
            +
                             embed_dim=96,
         | 
| 370 | 
            +
                             v_dim=64,
         | 
| 371 | 
            +
                             window_size=7,
         | 
| 372 | 
            +
                             num_heads=4,
         | 
| 373 | 
            +
                             depth=2,
         | 
| 374 | 
            +
                             patch_size=4,
         | 
| 375 | 
            +
                             in_chans=3,
         | 
| 376 | 
            +
                             norm_layer=nn.LayerNorm,
         | 
| 377 | 
            +
                             patch_norm=True):
         | 
| 378 | 
            +
                    super().__init__()
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    self.embed_dim = embed_dim
         | 
| 381 | 
            +
                    self.patch_norm = patch_norm
         | 
| 382 | 
            +
                    
         | 
| 383 | 
            +
                    if input_dim != embed_dim:
         | 
| 384 | 
            +
                        self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
         | 
| 385 | 
            +
                    else:
         | 
| 386 | 
            +
                        self.proj_x = None
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    if v_dim != embed_dim:
         | 
| 389 | 
            +
                        self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
         | 
| 390 | 
            +
                    elif embed_dim % v_dim == 0:
         | 
| 391 | 
            +
                        self.proj_v = None
         | 
| 392 | 
            +
             | 
| 393 | 
            +
        # For now, v_dim needs to equal embed_dim, because the output of the window attention is the input of the shifted-window attention.
         | 
| 394 | 
            +
                    v_dim = embed_dim
         | 
| 395 | 
            +
                    assert v_dim == embed_dim
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    self.crf_layer = BasicCRFLayer(
         | 
| 398 | 
            +
                            dim=embed_dim,
         | 
| 399 | 
            +
                            depth=depth,
         | 
| 400 | 
            +
                            num_heads=num_heads,
         | 
| 401 | 
            +
                            v_dim=v_dim,
         | 
| 402 | 
            +
                            window_size=window_size,
         | 
| 403 | 
            +
                            mlp_ratio=4.,
         | 
| 404 | 
            +
                            qkv_bias=True,
         | 
| 405 | 
            +
                            qk_scale=None,
         | 
| 406 | 
            +
                            drop=0.,
         | 
| 407 | 
            +
                            attn_drop=0.,
         | 
| 408 | 
            +
                            drop_path=0.,
         | 
| 409 | 
            +
                            norm_layer=norm_layer,
         | 
| 410 | 
            +
                            downsample=None,
         | 
| 411 | 
            +
                            use_checkpoint=False)
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                    layer = norm_layer(embed_dim)
         | 
| 414 | 
            +
                    layer_name = 'norm_crf'
         | 
| 415 | 
            +
                    self.add_module(layer_name, layer)
         | 
| 416 | 
            +
             | 
| 417 | 
            +
             | 
| 418 | 
            +
                def forward(self, x, v):
         | 
| 419 | 
            +
                    if self.proj_x is not None:
         | 
| 420 | 
            +
                        x = self.proj_x(x)
         | 
| 421 | 
            +
                    if self.proj_v is not None:
         | 
| 422 | 
            +
                        v = self.proj_v(v)
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    Wh, Ww = x.size(2), x.size(3)
         | 
| 425 | 
            +
                    x = x.flatten(2).transpose(1, 2)
         | 
| 426 | 
            +
                    v = v.transpose(1, 2).transpose(2, 3)
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww)
         | 
| 429 | 
            +
        norm_layer = getattr(self, 'norm_crf')
         | 
| 430 | 
            +
                    x_out = norm_layer(x_out)
         | 
| 431 | 
            +
                    out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous()
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    return out
         | 
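A minimal smoke test for the NewCRF module above (this also exercises CRFBlock and BasicCRFLayer); it is not part of the commit, and the feature-map shapes are illustrative assumptions rather than values taken from this Space's configs:

import torch
from iebins.networks.newcrf_layers import NewCRF

# x is the encoder feature map, v the value/decoder feature map; both share the same H x W.
crf = NewCRF(input_dim=192, embed_dim=128, v_dim=64, window_size=7, num_heads=4)
x = torch.randn(1, 192, 60, 80)   # (B, input_dim, H, W)
v = torch.randn(1, 64, 60, 80)    # (B, v_dim, H, W)
out = crf(x, v)                   # both are projected to embed_dim, then refined window-wise
print(out.shape)                  # torch.Size([1, 128, 60, 80])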
    	
        iebins/networks/newcrf_utils.py
    ADDED
    
    | @@ -0,0 +1,264 @@ | |
| 1 | 
            +
            import warnings
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import os.path as osp
         | 
| 4 | 
            +
            import pkgutil
         | 
| 5 | 
            +
            import warnings
         | 
| 6 | 
            +
            from collections import OrderedDict
         | 
| 7 | 
            +
            from importlib import import_module
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torchvision
         | 
| 11 | 
            +
            import torch.nn as nn
         | 
| 12 | 
            +
            from torch.utils import model_zoo
         | 
| 13 | 
            +
            from torch.nn import functional as F
         | 
| 14 | 
            +
            from torch.nn.parallel import DataParallel, DistributedDataParallel
         | 
| 15 | 
            +
            from torch import distributed as dist
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            TORCH_VERSION = torch.__version__
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def resize(input,
         | 
| 21 | 
            +
                       size=None,
         | 
| 22 | 
            +
                       scale_factor=None,
         | 
| 23 | 
            +
                       mode='nearest',
         | 
| 24 | 
            +
                       align_corners=None,
         | 
| 25 | 
            +
                       warning=True):
         | 
| 26 | 
            +
                if warning:
         | 
| 27 | 
            +
                    if size is not None and align_corners:
         | 
| 28 | 
            +
                        input_h, input_w = tuple(int(x) for x in input.shape[2:])
         | 
| 29 | 
            +
                        output_h, output_w = tuple(int(x) for x in size)
         | 
| 30 | 
            +
            if output_h > input_h or output_w > input_w:
         | 
| 31 | 
            +
                            if ((output_h > 1 and output_w > 1 and input_h > 1
         | 
| 32 | 
            +
                                 and input_w > 1) and (output_h - 1) % (input_h - 1)
         | 
| 33 | 
            +
                                    and (output_w - 1) % (input_w - 1)):
         | 
| 34 | 
            +
                                warnings.warn(
         | 
| 35 | 
            +
                                    f'When align_corners={align_corners}, '
         | 
| 36 | 
            +
                        'the output would be more aligned if '
         | 
| 37 | 
            +
                                    f'input size {(input_h, input_w)} is `x+1` and '
         | 
| 38 | 
            +
                                    f'out size {(output_h, output_w)} is `nx+1`')
         | 
| 39 | 
            +
                if isinstance(size, torch.Size):
         | 
| 40 | 
            +
                    size = tuple(int(x) for x in size)
         | 
| 41 | 
            +
                return F.interpolate(input, size, scale_factor, mode, align_corners)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def normal_init(module, mean=0, std=1, bias=0):
         | 
| 45 | 
            +
                if hasattr(module, 'weight') and module.weight is not None:
         | 
| 46 | 
            +
                    nn.init.normal_(module.weight, mean, std)
         | 
| 47 | 
            +
                if hasattr(module, 'bias') and module.bias is not None:
         | 
| 48 | 
            +
                    nn.init.constant_(module.bias, bias)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def is_module_wrapper(module):
         | 
| 52 | 
            +
                module_wrappers = (DataParallel, DistributedDataParallel)
         | 
| 53 | 
            +
                return isinstance(module, module_wrappers)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def get_dist_info():
         | 
| 57 | 
            +
                if TORCH_VERSION < '1.0':
         | 
| 58 | 
            +
                    initialized = dist._initialized
         | 
| 59 | 
            +
                else:
         | 
| 60 | 
            +
                    if dist.is_available():
         | 
| 61 | 
            +
                        initialized = dist.is_initialized()
         | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        initialized = False
         | 
| 64 | 
            +
                if initialized:
         | 
| 65 | 
            +
                    rank = dist.get_rank()
         | 
| 66 | 
            +
                    world_size = dist.get_world_size()
         | 
| 67 | 
            +
                else:
         | 
| 68 | 
            +
                    rank = 0
         | 
| 69 | 
            +
                    world_size = 1
         | 
| 70 | 
            +
                return rank, world_size
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def load_state_dict(module, state_dict, strict=False, logger=None):
         | 
| 74 | 
            +
                """Load state_dict to a module.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                This method is modified from :meth:`torch.nn.Module.load_state_dict`.
         | 
| 77 | 
            +
                Default value for ``strict`` is set to ``False`` and the message for
         | 
| 78 | 
            +
                param mismatch will be shown even if strict is False.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                Args:
         | 
| 81 | 
            +
                    module (Module): Module that receives the state_dict.
         | 
| 82 | 
            +
                    state_dict (OrderedDict): Weights.
         | 
| 83 | 
            +
                    strict (bool): whether to strictly enforce that the keys
         | 
| 84 | 
            +
                        in :attr:`state_dict` match the keys returned by this module's
         | 
| 85 | 
            +
                        :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
         | 
| 86 | 
            +
                    logger (:obj:`logging.Logger`, optional): Logger to log the error
         | 
| 87 | 
            +
                        message. If not specified, print function will be used.
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                unexpected_keys = []
         | 
| 90 | 
            +
                all_missing_keys = []
         | 
| 91 | 
            +
                err_msg = []
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                metadata = getattr(state_dict, '_metadata', None)
         | 
| 94 | 
            +
                state_dict = state_dict.copy()
         | 
| 95 | 
            +
                if metadata is not None:
         | 
| 96 | 
            +
                    state_dict._metadata = metadata
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                # use _load_from_state_dict to enable checkpoint version control
         | 
| 99 | 
            +
                def load(module, prefix=''):
         | 
| 100 | 
            +
                    # recursively check parallel module in case that the model has a
         | 
| 101 | 
            +
                    # complicated structure, e.g., nn.Module(nn.Module(DDP))
         | 
| 102 | 
            +
                    if is_module_wrapper(module):
         | 
| 103 | 
            +
                        module = module.module
         | 
| 104 | 
            +
                    local_metadata = {} if metadata is None else metadata.get(
         | 
| 105 | 
            +
                        prefix[:-1], {})
         | 
| 106 | 
            +
                    module._load_from_state_dict(state_dict, prefix, local_metadata, True,
         | 
| 107 | 
            +
                                                 all_missing_keys, unexpected_keys,
         | 
| 108 | 
            +
                                                 err_msg)
         | 
| 109 | 
            +
                    for name, child in module._modules.items():
         | 
| 110 | 
            +
                        if child is not None:
         | 
| 111 | 
            +
                            load(child, prefix + name + '.')
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                load(module)
         | 
| 114 | 
            +
                load = None  # break load->load reference cycle
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                # ignore "num_batches_tracked" of BN layers
         | 
| 117 | 
            +
                missing_keys = [
         | 
| 118 | 
            +
                    key for key in all_missing_keys if 'num_batches_tracked' not in key
         | 
| 119 | 
            +
                ]
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                if unexpected_keys:
         | 
| 122 | 
            +
                    err_msg.append('unexpected key in source '
         | 
| 123 | 
            +
                                   f'state_dict: {", ".join(unexpected_keys)}\n')
         | 
| 124 | 
            +
                if missing_keys:
         | 
| 125 | 
            +
                    err_msg.append(
         | 
| 126 | 
            +
                        f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                rank, _ = get_dist_info()
         | 
| 129 | 
            +
                if len(err_msg) > 0 and rank == 0:
         | 
| 130 | 
            +
                    err_msg.insert(
         | 
| 131 | 
            +
                        0, 'The model and loaded state dict do not match exactly\n')
         | 
| 132 | 
            +
                    err_msg = '\n'.join(err_msg)
         | 
| 133 | 
            +
                    if strict:
         | 
| 134 | 
            +
                        raise RuntimeError(err_msg)
         | 
| 135 | 
            +
                    elif logger is not None:
         | 
| 136 | 
            +
                        logger.warning(err_msg)
         | 
| 137 | 
            +
                    else:
         | 
| 138 | 
            +
                        print(err_msg)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            def load_url_dist(url, model_dir=None):
         | 
| 142 | 
            +
                """In distributed setting, this function only download checkpoint at local
         | 
| 143 | 
            +
                rank 0."""
         | 
| 144 | 
            +
                rank, world_size = get_dist_info()
         | 
| 145 | 
            +
                rank = int(os.environ.get('LOCAL_RANK', rank))
         | 
| 146 | 
            +
                if rank == 0:
         | 
| 147 | 
            +
                    checkpoint = model_zoo.load_url(url, model_dir=model_dir)
         | 
| 148 | 
            +
                if world_size > 1:
         | 
| 149 | 
            +
                    torch.distributed.barrier()
         | 
| 150 | 
            +
                    if rank > 0:
         | 
| 151 | 
            +
                        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
         | 
| 152 | 
            +
                return checkpoint
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def get_torchvision_models():
         | 
| 156 | 
            +
                model_urls = dict()
         | 
| 157 | 
            +
                for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
         | 
| 158 | 
            +
                    if ispkg:
         | 
| 159 | 
            +
                        continue
         | 
| 160 | 
            +
                    _zoo = import_module(f'torchvision.models.{name}')
         | 
| 161 | 
            +
                    if hasattr(_zoo, 'model_urls'):
         | 
| 162 | 
            +
                        _urls = getattr(_zoo, 'model_urls')
         | 
| 163 | 
            +
                        model_urls.update(_urls)
         | 
| 164 | 
            +
                return model_urls
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def _load_checkpoint(filename, map_location=None):
         | 
| 168 | 
            +
                """Load checkpoint from somewhere (modelzoo, file, url).
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                Args:
         | 
| 171 | 
            +
                    filename (str): Accept local filepath, URL, ``torchvision://xxx``,
         | 
| 172 | 
            +
                        ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
         | 
| 173 | 
            +
                        details.
         | 
| 174 | 
            +
                    map_location (str | None): Same as :func:`torch.load`. Default: None.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                Returns:
         | 
| 177 | 
            +
                    dict | OrderedDict: The loaded checkpoint. It can be either an
         | 
| 178 | 
            +
                        OrderedDict storing model weights or a dict containing other
         | 
| 179 | 
            +
                        information, which depends on the checkpoint.
         | 
| 180 | 
            +
                """
         | 
| 181 | 
            +
                if filename.startswith('modelzoo://'):
         | 
| 182 | 
            +
                    warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
         | 
| 183 | 
            +
                                  'use "torchvision://" instead')
         | 
| 184 | 
            +
                    model_urls = get_torchvision_models()
         | 
| 185 | 
            +
                    model_name = filename[11:]
         | 
| 186 | 
            +
                    checkpoint = load_url_dist(model_urls[model_name])
         | 
| 187 | 
            +
                else:
         | 
| 188 | 
            +
                    if not osp.isfile(filename):
         | 
| 189 | 
            +
                        raise IOError(f'{filename} is not a checkpoint file')
         | 
| 190 | 
            +
                    checkpoint = torch.load(filename, map_location=map_location)
         | 
| 191 | 
            +
                return checkpoint
         | 
| 192 | 
            +
             | 
| 193 | 
            +
             | 
| 194 | 
            +
            def load_checkpoint(model,
         | 
| 195 | 
            +
                                filename,
         | 
| 196 | 
            +
                                map_location='cpu',
         | 
| 197 | 
            +
                                strict=False,
         | 
| 198 | 
            +
                                logger=None):
         | 
| 199 | 
            +
                """Load checkpoint from a file or URI.
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                Args:
         | 
| 202 | 
            +
                    model (Module): Module to load checkpoint.
         | 
| 203 | 
            +
                    filename (str): Accept local filepath, URL, ``torchvision://xxx``,
         | 
| 204 | 
            +
                        ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
         | 
| 205 | 
            +
                        details.
         | 
| 206 | 
            +
                    map_location (str): Same as :func:`torch.load`.
         | 
| 207 | 
            +
                    strict (bool): Whether to allow different params for the model and
         | 
| 208 | 
            +
                        checkpoint.
         | 
| 209 | 
            +
                    logger (:mod:`logging.Logger` or None): The logger for error message.
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                Returns:
         | 
| 212 | 
            +
                    dict or OrderedDict: The loaded checkpoint.
         | 
| 213 | 
            +
                """
         | 
| 214 | 
            +
                checkpoint = _load_checkpoint(filename, map_location)
         | 
| 215 | 
            +
                # OrderedDict is a subclass of dict
         | 
| 216 | 
            +
                if not isinstance(checkpoint, dict):
         | 
| 217 | 
            +
                    raise RuntimeError(
         | 
| 218 | 
            +
                        f'No state_dict found in checkpoint file {filename}')
         | 
| 219 | 
            +
                # get state_dict from checkpoint
         | 
| 220 | 
            +
                if 'state_dict' in checkpoint:
         | 
| 221 | 
            +
                    state_dict = checkpoint['state_dict']
         | 
| 222 | 
            +
                elif 'model' in checkpoint:
         | 
| 223 | 
            +
                    state_dict = checkpoint['model']
         | 
| 224 | 
            +
                else:
         | 
| 225 | 
            +
                    state_dict = checkpoint
         | 
| 226 | 
            +
                # strip prefix of state_dict
         | 
| 227 | 
            +
                if list(state_dict.keys())[0].startswith('module.'):
         | 
| 228 | 
            +
                    state_dict = {k[7:]: v for k, v in state_dict.items()}
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                # for MoBY, load model of online branch
         | 
| 231 | 
            +
                if sorted(list(state_dict.keys()))[0].startswith('encoder'):
         | 
| 232 | 
            +
                    state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                # reshape absolute position embedding
         | 
| 235 | 
            +
                if state_dict.get('absolute_pos_embed') is not None:
         | 
| 236 | 
            +
                    absolute_pos_embed = state_dict['absolute_pos_embed']
         | 
| 237 | 
            +
                    N1, L, C1 = absolute_pos_embed.size()
         | 
| 238 | 
            +
                    N2, C2, H, W = model.absolute_pos_embed.size()
         | 
| 239 | 
            +
                    if N1 != N2 or C1 != C2 or L != H*W:
         | 
| 240 | 
            +
                        logger.warning("Error in loading absolute_pos_embed, pass")
         | 
| 241 | 
            +
                    else:
         | 
| 242 | 
            +
                        state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                # interpolate position bias table if needed
         | 
| 245 | 
            +
                relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
         | 
| 246 | 
            +
                for table_key in relative_position_bias_table_keys:
         | 
| 247 | 
            +
                    table_pretrained = state_dict[table_key]
         | 
| 248 | 
            +
                    table_current = model.state_dict()[table_key]
         | 
| 249 | 
            +
                    L1, nH1 = table_pretrained.size()
         | 
| 250 | 
            +
                    L2, nH2 = table_current.size()
         | 
| 251 | 
            +
                    if nH1 != nH2:
         | 
| 252 | 
            +
                        logger.warning(f"Error in loading {table_key}, pass")
         | 
| 253 | 
            +
                    else:
         | 
| 254 | 
            +
                        if L1 != L2:
         | 
| 255 | 
            +
                            S1 = int(L1 ** 0.5)
         | 
| 256 | 
            +
                            S2 = int(L2 ** 0.5)
         | 
| 257 | 
            +
                            table_pretrained_resized = F.interpolate(
         | 
| 258 | 
            +
                                 table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
         | 
| 259 | 
            +
                                 size=(S2, S2), mode='bicubic')
         | 
| 260 | 
            +
                            state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                # load state_dict
         | 
| 263 | 
            +
                load_state_dict(model, state_dict, strict, logger)
         | 
| 264 | 
            +
                return checkpoint
         | 
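A hedged usage sketch for load_checkpoint above; the module and the checkpoint path are placeholders, not artifacts of this commit:

import torch.nn as nn
from iebins.networks.newcrf_utils import load_checkpoint

model = nn.Sequential(nn.Conv2d(3, 8, 3))        # stand-in module; normally a Swin backbone
# strict=False: key mismatches are only reported (via load_state_dict above), not raised
checkpoint = load_checkpoint(model, '/path/to/pretrained.pth',
                             map_location='cpu', strict=False, logger=None)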
    	
        iebins/networks/resize.py
    ADDED
    
    | @@ -0,0 +1,51 @@ | |
| 1 | 
            +
            # Copyright (c) OpenMMLab. All rights reserved.
         | 
| 2 | 
            +
            import warnings
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def resize(input,
         | 
| 9 | 
            +
                       size=None,
         | 
| 10 | 
            +
                       scale_factor=None,
         | 
| 11 | 
            +
                       mode='nearest',
         | 
| 12 | 
            +
                       align_corners=None,
         | 
| 13 | 
            +
                       warning=False):
         | 
| 14 | 
            +
                if warning:
         | 
| 15 | 
            +
                    if size is not None and align_corners:
         | 
| 16 | 
            +
                        input_h, input_w = tuple(int(x) for x in input.shape[2:])
         | 
| 17 | 
            +
                        output_h, output_w = tuple(int(x) for x in size)
         | 
| 18 | 
            +
            if output_h > input_h or output_w > input_w:
         | 
| 19 | 
            +
                            if ((output_h > 1 and output_w > 1 and input_h > 1
         | 
| 20 | 
            +
                                 and input_w > 1) and (output_h - 1) % (input_h - 1)
         | 
| 21 | 
            +
                                    and (output_w - 1) % (input_w - 1)):
         | 
| 22 | 
            +
                                warnings.warn(
         | 
| 23 | 
            +
                                    f'When align_corners={align_corners}, '
         | 
| 24 | 
            +
                        'the output would be more aligned if '
         | 
| 25 | 
            +
                                    f'input size {(input_h, input_w)} is `x+1` and '
         | 
| 26 | 
            +
                                    f'out size {(output_h, output_w)} is `nx+1`')
         | 
| 27 | 
            +
                return F.interpolate(input, size, scale_factor, mode, align_corners)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class Upsample(nn.Module):
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                def __init__(self,
         | 
| 33 | 
            +
                             size=None,
         | 
| 34 | 
            +
                             scale_factor=None,
         | 
| 35 | 
            +
                             mode='nearest',
         | 
| 36 | 
            +
                             align_corners=None):
         | 
| 37 | 
            +
                    super(Upsample, self).__init__()
         | 
| 38 | 
            +
                    self.size = size
         | 
| 39 | 
            +
                    if isinstance(scale_factor, tuple):
         | 
| 40 | 
            +
                        self.scale_factor = tuple(float(factor) for factor in scale_factor)
         | 
| 41 | 
            +
                    else:
         | 
| 42 | 
            +
                        self.scale_factor = float(scale_factor) if scale_factor else None
         | 
| 43 | 
            +
                    self.mode = mode
         | 
| 44 | 
            +
                    self.align_corners = align_corners
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def forward(self, x):
         | 
| 47 | 
            +
                    if not self.size:
         | 
| 48 | 
            +
                        size = [int(t * self.scale_factor) for t in x.shape[-2:]]
         | 
| 49 | 
            +
                    else:
         | 
| 50 | 
            +
                        size = self.size
         | 
| 51 | 
            +
                    return resize(x, size, None, self.mode, self.align_corners)
         | 
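The Upsample wrapper above recomputes an explicit output size from the input shape instead of forwarding scale_factor to F.interpolate; a minimal illustration (shapes are assumptions):

import torch
from iebins.networks.resize import Upsample

up = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
feat = torch.randn(1, 64, 30, 40)   # (B, C, H, W)
print(up(feat).shape)               # torch.Size([1, 64, 60, 80])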
    	
        iebins/networks/swin_transformer.py
    ADDED
    
    | @@ -0,0 +1,620 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            import torch.utils.checkpoint as checkpoint
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from timm.models.layers import DropPath, to_2tuple, trunc_normal_
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .newcrf_utils import load_checkpoint
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class Mlp(nn.Module):
         | 
| 12 | 
            +
                """ Multilayer perceptron."""
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         | 
| 15 | 
            +
                    super().__init__()
         | 
| 16 | 
            +
                    out_features = out_features or in_features
         | 
| 17 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 18 | 
            +
                    self.fc1 = nn.Linear(in_features, hidden_features)
         | 
| 19 | 
            +
                    self.act = act_layer()
         | 
| 20 | 
            +
                    self.fc2 = nn.Linear(hidden_features, out_features)
         | 
| 21 | 
            +
                    self.drop = nn.Dropout(drop)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def forward(self, x):
         | 
| 24 | 
            +
                    x = self.fc1(x)
         | 
| 25 | 
            +
                    x = self.act(x)
         | 
| 26 | 
            +
                    x = self.drop(x)
         | 
| 27 | 
            +
                    x = self.fc2(x)
         | 
| 28 | 
            +
                    x = self.drop(x)
         | 
| 29 | 
            +
                    return x
         | 
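Mlp above is the usual two-layer token MLP: it widens the channel dimension internally and maps back to in_features, leaving the token layout unchanged. Illustrative shapes (not from the commit):

import torch
from iebins.networks.swin_transformer import Mlp

mlp = Mlp(in_features=96, hidden_features=384)
tokens = torch.randn(4, 49, 96)     # (num_windows*B, window_size*window_size, C)
print(mlp(tokens).shape)            # torch.Size([4, 49, 96])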
| 30 | 
            +
             | 
| 31 | 
            +
             | 
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

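# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file):
# window_partition and window_reverse are exact inverses whenever H and W are
# multiples of window_size, which the callers below guarantee via padding.
def _demo_window_roundtrip():
    x = torch.randn(2, 14, 14, 96)            # (B, H, W, C), 14 = 2 * 7
    windows = window_partition(x, 7)           # (2*2*2, 7, 7, 96)
    assert windows.shape == (8, 7, 7, 96)
    x_back = window_reverse(windows, 7, 14, 14)
    assert torch.equal(x, x_back)              # lossless round trip
# ---------------------------------------------------------------------------
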
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

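# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): each window of
# Wh*Ww tokens attends only within itself; the learned relative-position bias
# table contributes one scalar per head for every (row-offset, col-offset)
# pair, looked up through relative_position_index.
def _demo_window_attention():
    attn = WindowAttention(dim=96, window_size=to_2tuple(7), num_heads=3)
    x = torch.randn(8, 49, 96)               # (num_windows*B, Wh*Ww, C)
    out = attn(x)                            # mask=None -> plain W-MSA
    assert out.shape == (8, 49, 96)
    assert attn.relative_position_bias_table.shape == (13 * 13, 3)  # (2*7-1)^2, nH
# ---------------------------------------------------------------------------
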
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

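# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): the block
# consumes tokens in (B, H*W, C) layout and expects H/W to be set on the
# instance before the call; with shift_size=0 no attention mask is needed.
def _demo_swin_block():
    blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
    blk.H, blk.W = 20, 20                 # 20 is not a multiple of 7: the block pads internally
    x = torch.randn(2, 20 * 20, 96)
    out = blk(x, mask_matrix=None)
    assert out.shape == (2, 400, 96)
# ---------------------------------------------------------------------------
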
class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

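# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): patch merging
# concatenates each 2x2 neighbourhood (4*C channels) and projects it to 2*C,
# halving the spatial resolution between stages.
def _demo_patch_merging():
    merge = PatchMerging(dim=96)
    x = torch.randn(2, 56 * 56, 96)
    out = merge(x, H=56, W=56)
    assert out.shape == (2, 28 * 28, 192)   # (H/2)*(W/2) tokens, 2*C channels
# ---------------------------------------------------------------------------
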
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

             | 
| 396 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 397 | 
            +
                """ Image to Patch Embedding
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                Args:
         | 
| 400 | 
            +
                    patch_size (int): Patch token size. Default: 4.
         | 
| 401 | 
            +
                    in_chans (int): Number of input image channels. Default: 3.
         | 
| 402 | 
            +
                    embed_dim (int): Number of linear projection output channels. Default: 96.
         | 
| 403 | 
            +
                    norm_layer (nn.Module, optional): Normalization layer. Default: None
         | 
| 404 | 
            +
                """
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
         | 
| 407 | 
            +
                    super().__init__()
         | 
| 408 | 
            +
                    patch_size = to_2tuple(patch_size)
         | 
| 409 | 
            +
                    self.patch_size = patch_size
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    self.in_chans = in_chans
         | 
| 412 | 
            +
                    self.embed_dim = embed_dim
         | 
| 413 | 
            +
             | 
| 414 | 
            +
                    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
         | 
| 415 | 
            +
                    if norm_layer is not None:
         | 
| 416 | 
            +
                        self.norm = norm_layer(embed_dim)
         | 
| 417 | 
            +
                    else:
         | 
| 418 | 
            +
                        self.norm = None
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                def forward(self, x):
         | 
| 421 | 
            +
                    """Forward function."""
         | 
| 422 | 
            +
                    # padding
         | 
| 423 | 
            +
                    _, _, H, W = x.size()
         | 
| 424 | 
            +
                    if W % self.patch_size[1] != 0:
         | 
| 425 | 
            +
                        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
         | 
| 426 | 
            +
                    if H % self.patch_size[0] != 0:
         | 
| 427 | 
            +
                        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    x = self.proj(x)  # B C Wh Ww
         | 
| 430 | 
            +
                    if self.norm is not None:
         | 
| 431 | 
            +
                        Wh, Ww = x.size(2), x.size(3)
         | 
| 432 | 
            +
                        x = x.flatten(2).transpose(1, 2)
         | 
| 433 | 
            +
                        x = self.norm(x)
         | 
| 434 | 
            +
                        x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    return x
         | 
| 437 | 
            +
             | 
| 438 | 
            +
             | 
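# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): a strided 4x4
# convolution turns the image into a grid of patch tokens; inputs whose
# height/width are not multiples of 4 are right/bottom padded first.
def _demo_patch_embed():
    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    img = torch.randn(2, 3, 225, 227)        # deliberately not divisible by 4
    feat = embed(img)
    assert feat.shape == (2, 96, 57, 57)     # ceil(225/4), ceil(227/4)
# ---------------------------------------------------------------------------
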
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            # logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
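

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): the default
# configuration is Swin-T; the backbone returns a 4-level feature pyramid at
# strides 4/8/16/32 with channel widths embed_dim * 2**i.
def _demo_swin_backbone():
    backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
    backbone.eval()
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    shapes = [tuple(f.shape) for f in feats]
    assert shapes == [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
# ---------------------------------------------------------------------------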
    	
        iebins/networks/uper_crf_head.py
    ADDED
    
@@ -0,0 +1,364 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import ConvModule
from .newcrf_utils import resize, normal_init


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            # == if batch size = 1, BN is not supported, change to GN
            if pool_scale == 1: norm_cfg = dict(type='GN', requires_grad=True, num_groups=256)
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs

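# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): the PPM pools
# the input to several coarse grids, projects each to `channels` with a 1x1
# ConvModule, and resizes everything back to the input resolution. The config
# values below are assumptions for illustration only (channels must be
# divisible by the 256 GN groups used for the pool_scale == 1 branch).
def _demo_ppm():
    ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=512, channels=512,
              conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'),
              align_corners=False)
    x = torch.randn(2, 512, 15, 20)
    outs = ppm(x)                             # one tensor per pooling scale
    assert [tuple(o.shape) for o in outs] == [(2, 512, 15, 20)] * 4
# ---------------------------------------------------------------------------
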
| 62 | 
            +
                """Base class for BaseDecodeHead.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                Args:
         | 
| 65 | 
            +
                    in_channels (int|Sequence[int]): Input channels.
         | 
| 66 | 
            +
                    channels (int): Channels after modules, before conv_seg.
         | 
| 67 | 
            +
                    num_classes (int): Number of classes.
         | 
| 68 | 
            +
                    dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
         | 
| 69 | 
            +
                    conv_cfg (dict|None): Config of conv layers. Default: None.
         | 
| 70 | 
            +
                    norm_cfg (dict|None): Config of norm layers. Default: None.
         | 
| 71 | 
            +
                    act_cfg (dict): Config of activation layers.
         | 
| 72 | 
            +
                        Default: dict(type='ReLU')
         | 
| 73 | 
            +
                    in_index (int|Sequence[int]): Input feature index. Default: -1
         | 
| 74 | 
            +
                    input_transform (str|None): Transformation type of input features.
         | 
| 75 | 
            +
                        Options: 'resize_concat', 'multiple_select', None.
         | 
| 76 | 
            +
                        'resize_concat': Multiple feature maps will be resize to the
         | 
| 77 | 
            +
                            same size as first one and than concat together.
         | 
| 78 | 
            +
                            Usually used in FCN head of HRNet.
         | 
| 79 | 
            +
                        'multiple_select': Multiple feature maps will be bundle into
         | 
| 80 | 
            +
                            a list and passed into decode head.
         | 
| 81 | 
            +
                        None: Only one select feature map is allowed.
         | 
| 82 | 
            +
                        Default: None.
         | 
| 83 | 
            +
                    loss_decode (dict): Config of decode loss.
         | 
| 84 | 
            +
                        Default: dict(type='CrossEntropyLoss').
         | 
| 85 | 
            +
                    ignore_index (int | None): The label index to be ignored. When using
         | 
| 86 | 
            +
                        masked BCE loss, ignore_index should be set to None. Default: 255
         | 
| 87 | 
            +
                    sampler (dict|None): The config of segmentation map sampler.
         | 
| 88 | 
            +
                        Default: None.
         | 
| 89 | 
            +
                    align_corners (bool): align_corners argument of F.interpolate.
         | 
| 90 | 
            +
                        Default: False.
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def __init__(self,
         | 
| 94 | 
            +
                             in_channels,
         | 
| 95 | 
            +
                             channels,
         | 
| 96 | 
            +
                             *,
         | 
| 97 | 
            +
                             num_classes,
         | 
| 98 | 
            +
                             dropout_ratio=0.1,
         | 
| 99 | 
            +
                             conv_cfg=None,
         | 
| 100 | 
            +
                             norm_cfg=None,
         | 
| 101 | 
            +
                             act_cfg=dict(type='ReLU'),
         | 
| 102 | 
            +
                             in_index=-1,
         | 
| 103 | 
            +
                             input_transform=None,
         | 
| 104 | 
            +
                             loss_decode=dict(
         | 
| 105 | 
            +
                                 type='CrossEntropyLoss',
         | 
| 106 | 
            +
                                 use_sigmoid=False,
         | 
| 107 | 
            +
                                 loss_weight=1.0),
         | 
| 108 | 
            +
                             ignore_index=255,
         | 
| 109 | 
            +
                             sampler=None,
         | 
| 110 | 
            +
                             align_corners=False):
         | 
| 111 | 
            +
                    super(BaseDecodeHead, self).__init__()
         | 
| 112 | 
            +
                    self._init_inputs(in_channels, in_index, input_transform)
         | 
| 113 | 
            +
                    self.channels = channels
         | 
| 114 | 
            +
                    self.num_classes = num_classes
         | 
| 115 | 
            +
                    self.dropout_ratio = dropout_ratio
         | 
| 116 | 
            +
                    self.conv_cfg = conv_cfg
         | 
| 117 | 
            +
                    self.norm_cfg = norm_cfg
         | 
| 118 | 
            +
                    self.act_cfg = act_cfg
         | 
| 119 | 
            +
                    self.in_index = in_index
         | 
| 120 | 
            +
                    # self.loss_decode = build_loss(loss_decode)
         | 
| 121 | 
            +
                    self.ignore_index = ignore_index
         | 
| 122 | 
            +
                    self.align_corners = align_corners
         | 
| 123 | 
            +
                    # if sampler is not None:
         | 
| 124 | 
            +
                    #     self.sampler = build_pixel_sampler(sampler, context=self)
         | 
| 125 | 
            +
                    # else:
         | 
| 126 | 
            +
                    #     self.sampler = None
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
         | 
| 129 | 
            +
                    # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1)
         | 
| 130 | 
            +
                    if dropout_ratio > 0:
         | 
| 131 | 
            +
                        self.dropout = nn.Dropout2d(dropout_ratio)
         | 
| 132 | 
            +
                    else:
         | 
| 133 | 
            +
                        self.dropout = None
         | 
| 134 | 
            +
                    self.fp16_enabled = False
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def extra_repr(self):
         | 
| 137 | 
            +
                    """Extra repr."""
         | 
| 138 | 
            +
                    s = f'input_transform={self.input_transform}, ' \
         | 
| 139 | 
            +
                        f'ignore_index={self.ignore_index}, ' \
         | 
| 140 | 
            +
                        f'align_corners={self.align_corners}'
         | 
| 141 | 
            +
                    return s
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                def _init_inputs(self, in_channels, in_index, input_transform):
         | 
| 144 | 
            +
                    """Check and initialize input transforms.
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    The in_channels, in_index and input_transform must match.
         | 
| 147 | 
            +
                    Specifically, when input_transform is None, only single feature map
         | 
| 148 | 
            +
                    will be selected. So in_channels and in_index must be of type int.
         | 
| 149 | 
            +
                    When input_transform
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    Args:
         | 
| 152 | 
            +
                        in_channels (int|Sequence[int]): Input channels.
         | 
| 153 | 
            +
                        in_index (int|Sequence[int]): Input feature index.
         | 
| 154 | 
            +
                        input_transform (str|None): Transformation type of input features.
         | 
| 155 | 
            +
                            Options: 'resize_concat', 'multiple_select', None.
         | 
| 156 | 
            +
                            'resize_concat': Multiple feature maps will be resize to the
         | 
| 157 | 
            +
                                same size as first one and than concat together.
         | 
| 158 | 
            +
                                Usually used in FCN head of HRNet.
         | 
| 159 | 
            +
                            'multiple_select': Multiple feature maps will be bundle into
         | 
| 160 | 
            +
                                a list and passed into decode head.
         | 
| 161 | 
            +
                            None: Only one select feature map is allowed.
         | 
| 162 | 
            +
                    """
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    if input_transform is not None:
         | 
| 165 | 
            +
                        assert input_transform in ['resize_concat', 'multiple_select']
         | 
| 166 | 
            +
                    self.input_transform = input_transform
         | 
| 167 | 
            +
                    self.in_index = in_index
         | 
| 168 | 
            +
                    if input_transform is not None:
         | 
| 169 | 
            +
                        assert isinstance(in_channels, (list, tuple))
         | 
| 170 | 
            +
                        assert isinstance(in_index, (list, tuple))
         | 
| 171 | 
            +
                        assert len(in_channels) == len(in_index)
         | 
| 172 | 
            +
                        if input_transform == 'resize_concat':
         | 
| 173 | 
            +
                            self.in_channels = sum(in_channels)
         | 
| 174 | 
            +
                        else:
         | 
| 175 | 
            +
                            self.in_channels = in_channels
         | 
| 176 | 
            +
                    else:
         | 
| 177 | 
            +
                        assert isinstance(in_channels, int)
         | 
| 178 | 
            +
                        assert isinstance(in_index, int)
         | 
| 179 | 
            +
                        self.in_channels = in_channels
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def init_weights(self):
         | 
| 182 | 
            +
                    """Initialize weights of classification layer."""
         | 
| 183 | 
            +
                    # normal_init(self.conv_seg, mean=0, std=0.01)
         | 
| 184 | 
            +
                    # normal_init(self.conv1, mean=0, std=0.01)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def _transform_inputs(self, inputs):
         | 
| 187 | 
            +
                    """Transform inputs for decoder.
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    Args:
         | 
| 190 | 
            +
                        inputs (list[Tensor]): List of multi-level img features.
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    Returns:
         | 
| 193 | 
            +
                        Tensor: The transformed inputs
         | 
| 194 | 
            +
                    """
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    if self.input_transform == 'resize_concat':
         | 
| 197 | 
            +
                        inputs = [inputs[i] for i in self.in_index]
         | 
| 198 | 
            +
                        upsampled_inputs = [
         | 
| 199 | 
            +
                            resize(
         | 
| 200 | 
            +
                                input=x,
         | 
| 201 | 
            +
                                size=inputs[0].shape[2:],
         | 
| 202 | 
            +
                                mode='bilinear',
         | 
| 203 | 
            +
                                align_corners=self.align_corners) for x in inputs
         | 
| 204 | 
            +
                        ]
         | 
| 205 | 
            +
                        inputs = torch.cat(upsampled_inputs, dim=1)
         | 
| 206 | 
            +
                    elif self.input_transform == 'multiple_select':
         | 
| 207 | 
            +
                        inputs = [inputs[i] for i in self.in_index]
         | 
| 208 | 
            +
                    else:
         | 
| 209 | 
            +
                        inputs = inputs[self.in_index]
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    return inputs
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                def forward(self, inputs):
         | 
| 214 | 
            +
                    """Placeholder of forward function."""
         | 
| 215 | 
            +
                    pass
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
         | 
| 218 | 
            +
                    """Forward function for training.
         | 
| 219 | 
            +
                    Args:
         | 
| 220 | 
            +
                        inputs (list[Tensor]): List of multi-level img features.
         | 
| 221 | 
            +
                        img_metas (list[dict]): List of image info dict where each dict
         | 
| 222 | 
            +
                            has: 'img_shape', 'scale_factor', 'flip', and may also contain
         | 
| 223 | 
            +
                            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
         | 
| 224 | 
            +
                            For details on the values of these keys see
         | 
| 225 | 
            +
                            `mmseg/datasets/pipelines/formatting.py:Collect`.
         | 
| 226 | 
            +
                        gt_semantic_seg (Tensor): Semantic segmentation masks
         | 
| 227 | 
            +
                            used if the architecture supports semantic segmentation task.
         | 
| 228 | 
            +
                        train_cfg (dict): The training config.
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    Returns:
         | 
| 231 | 
            +
                        dict[str, Tensor]: a dictionary of loss components
         | 
| 232 | 
            +
                    """
         | 
| 233 | 
            +
                    seg_logits = self.forward(inputs)
         | 
| 234 | 
            +
                    losses = self.losses(seg_logits, gt_semantic_seg)
         | 
| 235 | 
            +
                    return losses
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                def forward_test(self, inputs, img_metas, test_cfg):
         | 
| 238 | 
            +
                    """Forward function for testing.
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    Args:
         | 
| 241 | 
            +
                        inputs (list[Tensor]): List of multi-level img features.
         | 
| 242 | 
            +
                        img_metas (list[dict]): List of image info dict where each dict
         | 
| 243 | 
            +
                            has: 'img_shape', 'scale_factor', 'flip', and may also contain
         | 
| 244 | 
            +
                            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
         | 
| 245 | 
            +
                            For details on the values of these keys see
         | 
| 246 | 
            +
                            `mmseg/datasets/pipelines/formatting.py:Collect`.
         | 
| 247 | 
            +
                        test_cfg (dict): The testing config.
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    Returns:
         | 
| 250 | 
            +
                        Tensor: Output segmentation map.
         | 
| 251 | 
            +
                    """
         | 
| 252 | 
            +
                    return self.forward(inputs)
         | 
| 253 | 
            +
             | 
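The contract enforced by _init_inputs is easiest to see with 'multiple_select': in_channels and in_index must be equal-length sequences, and _transform_inputs then simply picks those levels out of the feature list. A small sketch with made-up channel counts follows; only torch is needed here, since no conv layers are built.

# Hypothetical configuration, purely to illustrate the in_channels/in_index pairing.
import torch

head = BaseDecodeHead([96, 192, 384, 768], 512, num_classes=1,
                      in_index=[0, 1, 2, 3],
                      input_transform='multiple_select')
feats = [torch.randn(1, c, 60 // 2 ** i, 80 // 2 ** i)
         for i, c in enumerate([96, 192, 384, 768])]
selected = head._transform_inputs(feats)   # the same 4 feature maps, selected by index
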
class UPerHead(BaseDecodeHead):
    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, inputs):
        """Forward function."""

        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        return fpn_outs[0]

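Note that this UPerHead variant keeps the PSP branch commented out and returns only fpn_outs[0], the highest-resolution lateral after top-down fusion, rather than a fused multi-scale map. A rough shape sketch under the same illustrative assumptions as above (mmcv-style ConvModule available, invented Swin-like channel counts):

# Illustrative only: shapes are invented, not read from the commit.
import torch

uper = UPerHead(in_channels=[192, 384, 768, 1536], channels=512,
                num_classes=1, in_index=[0, 1, 2, 3],
                norm_cfg=dict(type='BN', requires_grad=True))
feats = [torch.randn(1, 192, 120, 160), torch.randn(1, 384, 60, 80),
         torch.randn(1, 768, 30, 40), torch.randn(1, 1536, 15, 20)]
out = uper(feats)   # single tensor of shape (1, 512, 120, 160)
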
class PSP(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSP, self).__init__(
            input_transform='multiple_select', **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        return self.psp_forward(inputs)
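PSP, by contrast, consumes only the coarsest feature map (inputs[-1]): the pooled context from PPM is concatenated back onto it and squeezed to self.channels by the 3x3 bottleneck. A sketch, again with invented shapes and assuming mmcv-style ConvModule:

# Illustrative only: shapes are invented, not read from the commit.
import torch

psp = PSP(in_channels=[192, 384, 768, 1536], channels=512,
          num_classes=1, in_index=[0, 1, 2, 3],
          norm_cfg=dict(type='BN', requires_grad=True))
feats = [torch.randn(1, 192, 120, 160), torch.randn(1, 384, 60, 80),
         torch.randn(1, 768, 30, 40), torch.randn(1, 1536, 15, 20)]
context = psp(feats)   # (1, 512, 15, 20), context at the coarsest scale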
    	
        iebins/sum_depth.py
    ADDED
    
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import numpy as np


class Sum_depth(nn.Module):
    def __init__(self):
        super(Sum_depth, self).__init__()
        self.sum_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        sum_k = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])

        sum_k = torch.from_numpy(sum_k).float().view(1, 1, 3, 3)
        self.sum_conv.weight = nn.Parameter(sum_k)

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = self.sum_conv(x)
        out = out.contiguous().view(-1, 1, x.size(2), x.size(3))

        return out
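Sum_depth is a frozen 3x3 box-filter convolution (all-ones kernel, no bias, gradients disabled). One plausible use, sketched below, is counting valid pixels in each 3x3 neighbourhood of a sparse depth or mask tensor; the tensors here are random placeholders, not how train.py necessarily uses it.

# Illustrative usage of the frozen box-sum filter.
import torch

sum_depth = Sum_depth()
depth = torch.rand(2, 1, 480, 640)      # placeholder depth batch (B, 1, H, W)
valid = (depth > 0.5).float()           # placeholder validity mask
neighbour_count = sum_depth(valid)      # (2, 1, 480, 640), each value in [0, 9]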
    	
        iebins/test.py
    ADDED
    
@@ -0,0 +1,209 @@
from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
from torch.autograd import Variable

import os, sys, errno
import argparse
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import open3d as o3d

from utils import post_process_depth, D_to_cloud, flip_lr, inv_normalize

from networks.NewCRFDepth import NewCRFDepth


def convert_arg_line_to_args(arg_line):
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg


parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--data_path', type=str, help='path to the data', required=True)
parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
parser.add_argument('--dataset', type=str, help='dataset to train on', default='nyu')
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--pred_clouds', help='if set, predict point clouds', action='store_true')
parser.add_argument('--save_viz', help='if set, save visualization of the outputs', action='store_true')

if sys.argv.__len__() == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

if args.dataset == 'kitti' or args.dataset == 'nyu':
    from dataloaders.dataloader import NewDataLoader

model_dir = os.path.dirname(args.checkpoint_path)
sys.path.append(model_dir)


def get_num_lines(file_path):
    f = open(file_path, 'r')
    lines = f.readlines()
    f.close()
    return len(lines)


def test(params):
    """Test function."""
    args.mode = 'test'
    dataloader = NewDataLoader(args, 'test')

    model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_test_samples = get_num_lines(args.filenames_file)

    with open(args.filenames_file) as f:
        lines = f.readlines()

    print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))

    pred_depths = []
    pred_clouds = []
    start_time = time.time()
    with torch.no_grad():
        for _, sample in enumerate(tqdm(dataloader.data)):
            image = Variable(sample['image'].cuda())
            inv_K_p = Variable(sample['inv_K_p'].cuda())
            b, _, h, w = image.shape
            depth_to_cloud = D_to_cloud(b, h, w).cuda()

            # Predict
            pred_depths_r_list, _, _ = model(image)
            post_process = True
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])

            if args.pred_clouds:
                if args.dataset == 'nyu':
                    color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0)[45:472, 43:608, :].reshape(-1, 3).cpu().numpy()
                    points = depth_to_cloud(pred_depth, inv_K_p).reshape(1, h, w, 3)[:, 45:472, 43:608, :].reshape(1, -1, 3)
                    points = points.cpu().numpy().squeeze()
                else:
                    color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0).reshape(-1, 3).cpu().numpy()
                    points = depth_to_cloud(pred_depth, inv_K_p)
                    points = points.cpu().numpy().squeeze()
                pc = o3d.geometry.PointCloud()
                pc.points = o3d.utility.Vector3dVector(points)
                pc.colors = o3d.utility.Vector3dVector(color)

                pred_clouds.append(pc)

            pred_depth = pred_depth.cpu().numpy().squeeze()

            if args.do_kb_crop:
                height, width = 352, 1216
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
                pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
                pred_depth = pred_depth_uncropped

            pred_depths.append(pred_depth)

    elapsed_time = time.time() - start_time
    print('Elapsed time: %s' % str(elapsed_time))
    print('Done.')

    save_name = 'models/result_' + args.model_name

    print('Saving result pngs..')
    if not os.path.exists(save_name):
        try:
            os.mkdir(save_name)
            os.mkdir(save_name + '/raw')
            os.mkdir(save_name + '/cmap')
            os.mkdir(save_name + '/rgb')
            os.mkdir(save_name + '/gt')
            os.mkdir(save_name + '/cloud')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    for s in tqdm(range(num_test_samples)):
        if args.dataset == 'kitti':
            date_drive = lines[s].split('/')[1]
            filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace(
                '.jpg', '.png')
            filename_pred_ply = save_name + '/cloud/' + date_drive + '_' + lines[s].split()[0].split('/')[-1][:-4] + '_' + 'iebins' + '.ply'
            filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[
                -1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
        elif args.dataset == 'kittipred':
            filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
        else:
            scene_name = lines[s].split()[0].split('/')[0]
            filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
                '.jpg', '.png')
            filename_pred_ply = save_name + '/cloud/' + scene_name + '_' + lines[s].split()[0].split('/')[1][:-4] + '_' + 'iebins' + '.ply'
            filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
                '.jpg', '.png')
            filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
                '.jpg', '_gt.png')
            filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1]

        rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0])
        image = cv2.imread(rgb_path)
        if args.dataset == 'nyu':
            gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
            gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0  # Visualization purpose only
            gt[gt == 0] = np.amax(gt)

        pred_depth = pred_depths[s]

        if args.dataset == 'kitti' or args.dataset == 'kittipred':
            pred_depth_scaled = pred_depth * 256.0
        else:
            pred_depth_scaled = pred_depth * 1000.0

        pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
        cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if args.save_viz:
            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
            if args.dataset == 'nyu':
                plt.imsave(filename_gt_png, (10 - gt) / 10, cmap='jet')
                pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
                plt.imsave(filename_cmap_png, (10 - pred_depth) / 10, cmap='jet')
            else:
                plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='magma')

        if args.pred_clouds:
            pred_cloud = pred_clouds[s]
            o3d.io.write_point_cloud(filename_pred_ply, pred_cloud)

    return


if __name__ == '__main__':
    test(args)
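Because the parser sets fromfile_prefix_chars='@' and a lone positional argument is rewritten to '@<file>', the script is normally driven by an arguments file. A hypothetical example follows; the file name, paths and split file are placeholders, not taken from this commit.

# arguments_test_nyu.txt (hypothetical):
#   --data_path ./datasets/nyu/
#   --filenames_file ./data_splits/nyu_test_files.txt
#   --checkpoint_path ./checkpoints/nyu_L.pth
#   --dataset nyu
#   --max_depth 10
#   --save_viz
# convert_arg_line_to_args splits each line on whitespace, so several tokens may
# share one line. The invocation would then look like:
#   python iebins/test.py arguments_test_nyu.txt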
    	
        iebins/train.py
    ADDED
    
@@ -0,0 +1,499 @@
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.utils as utils
         | 
| 4 | 
            +
            import torch.backends.cudnn as cudnn
         | 
| 5 | 
            +
            import torch.distributed as dist
         | 
| 6 | 
            +
            import torch.multiprocessing as mp
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import os, sys, time
         | 
| 9 | 
            +
            from telnetlib import IP
         | 
| 10 | 
            +
            import argparse
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from tensorboardX import SummaryWriter
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from utils import post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics, entropy_loss, colormap, \
         | 
| 17 | 
            +
                                   block_print, enable_print, normalize_result, inv_normalize, convert_arg_line_to_args, colormap_magma
         | 
| 18 | 
            +
            from networks.NewCRFDepth import NewCRFDepth
         | 
| 19 | 
            +
            from networks.depth_update import *
         | 
| 20 | 
            +
            from datetime import datetime
         | 
| 21 | 
            +
            from sum_depth import Sum_depth
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
         | 
| 25 | 
            +
            parser.convert_arg_line_to_args = convert_arg_line_to_args
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            parser.add_argument('--mode',                      type=str,   help='train or test', default='train')
         | 
| 28 | 
            +
            parser.add_argument('--model_name',                type=str,   help='model name', default='iebins')
         | 
| 29 | 
            +
            parser.add_argument('--encoder',                   type=str,   help='type of encoder, base07, large07, tiny07', default='large07')
         | 
| 30 | 
            +
            parser.add_argument('--pretrain',                  type=str,   help='path of pretrained encoder', default=None)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            # Dataset
         | 
| 33 | 
            +
            parser.add_argument('--dataset',                   type=str,   help='dataset to train on, kitti or nyu', default='nyu')
         | 
| 34 | 
            +
            parser.add_argument('--data_path',                 type=str,   help='path to the data', required=True)
         | 
| 35 | 
            +
            parser.add_argument('--gt_path',                   type=str,   help='path to the groundtruth data', required=True)
         | 
| 36 | 
            +
            parser.add_argument('--filenames_file',            type=str,   help='path to the filenames text file', required=True)
         | 
| 37 | 
            +
            parser.add_argument('--input_height',              type=int,   help='input height', default=480)
         | 
| 38 | 
            +
            parser.add_argument('--input_width',               type=int,   help='input width',  default=640)
         | 
| 39 | 
            +
            parser.add_argument('--max_depth',                 type=float, help='maximum depth in estimation', default=10)
         | 
| 40 | 
            +
            parser.add_argument('--min_depth',                 type=float, help='minimum depth in estimation', default=0.1)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            # Log and save
         | 
| 43 | 
            +
            parser.add_argument('--log_directory',             type=str,   help='directory to save checkpoints and summaries', default='')
         | 
| 44 | 
            +
            parser.add_argument('--checkpoint_path',           type=str,   help='path to a checkpoint to load', default='')
         | 
| 45 | 
            +
            parser.add_argument('--log_freq',                  type=int,   help='Logging frequency in global steps', default=100)
         | 
| 46 | 
            +
            parser.add_argument('--save_freq',                 type=int,   help='Checkpoint saving frequency in global steps', default=5000)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            # Training
         | 
| 49 | 
            +
            parser.add_argument('--weight_decay',              type=float, help='weight decay factor for optimization', default=1e-2)
         | 
| 50 | 
            +
            parser.add_argument('--retrain',                               help='if used with checkpoint_path, will restart training from step zero', action='store_true')
         | 
| 51 | 
            +
            parser.add_argument('--adam_eps',                  type=float, help='epsilon in Adam optimizer', default=1e-6)
         | 
| 52 | 
            +
            parser.add_argument('--batch_size',                type=int,   help='batch size', default=4)
         | 
| 53 | 
            +
            parser.add_argument('--num_epochs',                type=int,   help='number of epochs', default=50)
         | 
| 54 | 
            +
            parser.add_argument('--learning_rate',             type=float, help='initial learning rate', default=1e-4)
         | 
| 55 | 
            +
            parser.add_argument('--end_learning_rate',         type=float, help='end learning rate', default=-1)
         | 
| 56 | 
            +
            parser.add_argument('--variance_focus',            type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Preprocessing
         | 
| 59 | 
            +
            parser.add_argument('--do_random_rotate',                      help='if set, will perform random rotation for augmentation', action='store_true')
         | 
| 60 | 
            +
            parser.add_argument('--degree',                    type=float, help='random rotation maximum degree', default=2.5)
         | 
| 61 | 
            +
            parser.add_argument('--do_kb_crop',                            help='if set, crop input images as KITTI benchmark images', action='store_true')
         | 
| 62 | 
            +
            parser.add_argument('--use_right',                             help='if set, will randomly use right images when training on KITTI', action='store_true')
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            # Multi-gpu training
         | 
| 65 | 
            +
            parser.add_argument('--num_threads',               type=int,   help='number of threads to use for data loading', default=1)
         | 
| 66 | 
            +
            parser.add_argument('--world_size',                type=int,   help='number of nodes for distributed training', default=1)
         | 
| 67 | 
            +
            parser.add_argument('--rank',                      type=int,   help='node rank for distributed training', default=0)
         | 
| 68 | 
            +
            parser.add_argument('--dist_url',                  type=str,   help='url used to set up distributed training', default='tcp://127.0.0.1:1234')
         | 
| 69 | 
            +
            parser.add_argument('--dist_backend',              type=str,   help='distributed backend', default='nccl')
         | 
| 70 | 
            +
            parser.add_argument('--gpu',                       type=int,   help='GPU id to use.', default=None)
         | 
| 71 | 
            +
            parser.add_argument('--multiprocessing_distributed',           help='Use multi-processing distributed training to launch '
         | 
| 72 | 
            +
                                                                                'N processes per node, which has N GPUs. This is the '
         | 
| 73 | 
            +
                                                                                'fastest way to use PyTorch for either single node or '
         | 
| 74 | 
            +
                                                                                'multi node data parallel training', action='store_true',)
         | 
| 75 | 
            +
            # Online eval
         | 
| 76 | 
            +
            parser.add_argument('--do_online_eval',                        help='if set, perform online eval in every eval_freq steps', action='store_true')
         | 
| 77 | 
            +
            parser.add_argument('--data_path_eval',            type=str,   help='path to the data for online evaluation', required=False)
         | 
| 78 | 
            +
            parser.add_argument('--gt_path_eval',              type=str,   help='path to the groundtruth data for online evaluation', required=False)
         | 
| 79 | 
            +
            parser.add_argument('--filenames_file_eval',       type=str,   help='path to the filenames text file for online evaluation', required=False)
         | 
| 80 | 
            +
            parser.add_argument('--min_depth_eval',            type=float, help='minimum depth for evaluation', default=1e-3)
         | 
| 81 | 
            +
            parser.add_argument('--max_depth_eval',            type=float, help='maximum depth for evaluation', default=80)
         | 
| 82 | 
            +
            parser.add_argument('--eigen_crop',                            help='if set, crops according to Eigen NIPS14', action='store_true')
         | 
| 83 | 
            +
            parser.add_argument('--garg_crop',                             help='if set, crops according to Garg  ECCV16', action='store_true')
         | 
| 84 | 
            +
            parser.add_argument('--eval_freq',                 type=int,   help='Online evaluation frequency in global steps', default=500)
         | 
| 85 | 
            +
            parser.add_argument('--eval_summary_directory',    type=str,   help='output directory for eval summary,'
         | 
| 86 | 
            +
                                                                                'if empty outputs to checkpoint folder', default='')
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            if sys.argv.__len__() == 2:
         | 
| 89 | 
            +
                arg_filename_with_prefix = '@' + sys.argv[1]
         | 
| 90 | 
            +
                args = parser.parse_args([arg_filename_with_prefix])
         | 
| 91 | 
            +
            else:
         | 
| 92 | 
            +
                args = parser.parse_args()
         | 
| 93 | 
            +
             | 
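A minimal usage sketch (the arguments file name is a placeholder): when exactly one extra argument is passed, the branch above prefixes it with '@' so argparse expands the options from that file, e.g.

    python iebins/train.py @<arguments_train>.txt

convert_arg_line_to_args (defined in iebins/utils.py further down) splits each line of that file into whitespace-separated '--flag value' tokens.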
| 94 | 
            +
            if args.dataset == 'kitti' or args.dataset == 'nyu':
         | 
| 95 | 
            +
                from dataloaders.dataloader import NewDataLoader
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def online_eval(model, dataloader_eval, gpu, epoch, ngpus, group, post_process=False):
         | 
| 99 | 
            +
                eval_measures = torch.zeros(10).cuda(device=gpu)
         | 
| 100 | 
            +
                for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
         | 
| 101 | 
            +
                    with torch.no_grad():
         | 
| 102 | 
            +
                        image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True))
         | 
| 103 | 
            +
                        gt_depth = eval_sample_batched['depth']
         | 
| 104 | 
            +
                        has_valid_depth = eval_sample_batched['has_valid_depth']
         | 
| 105 | 
            +
                        if not has_valid_depth:
         | 
| 106 | 
            +
                            # print('Invalid depth. continue.')
         | 
| 107 | 
            +
                            continue
         | 
| 108 | 
            +
                       
         | 
| 109 | 
            +
                        pred_depths_r_list, _, _ = model(image)
         | 
| 110 | 
            +
                        if post_process:
         | 
| 111 | 
            +
                            image_flipped = flip_lr(image)
         | 
| 112 | 
            +
                            pred_depths_r_list_flipped, _, _ = model(image_flipped)
         | 
| 113 | 
            +
                            pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                        pred_depth = pred_depth.cpu().numpy().squeeze()
         | 
| 116 | 
            +
                        gt_depth = gt_depth.cpu().numpy().squeeze()
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    if args.do_kb_crop:
         | 
| 119 | 
            +
                        height, width = gt_depth.shape
         | 
| 120 | 
            +
                        top_margin = int(height - 352)
         | 
| 121 | 
            +
                        left_margin = int((width - 1216) / 2)
         | 
| 122 | 
            +
                        pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
         | 
| 123 | 
            +
                        pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
         | 
| 124 | 
            +
                        pred_depth = pred_depth_uncropped
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
         | 
| 127 | 
            +
                    pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
         | 
| 128 | 
            +
                    pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
         | 
| 129 | 
            +
                    pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    if args.garg_crop or args.eigen_crop:
         | 
| 134 | 
            +
                        gt_height, gt_width = gt_depth.shape
         | 
| 135 | 
            +
                        eval_mask = np.zeros(valid_mask.shape)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                        if args.garg_crop:
         | 
| 138 | 
            +
                            eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                        elif args.eigen_crop:
         | 
| 141 | 
            +
                            if args.dataset == 'kitti':
         | 
| 142 | 
            +
                                eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
         | 
| 143 | 
            +
                            elif args.dataset == 'nyu':
         | 
| 144 | 
            +
                                eval_mask[45:471, 41:601] = 1
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                        valid_mask = np.logical_and(valid_mask, eval_mask)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    eval_measures[:9] += torch.tensor(measures).cuda(device=gpu)
         | 
| 151 | 
            +
                    eval_measures[9] += 1
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                if args.multiprocessing_distributed:
         | 
| 154 | 
            +
                    # group = dist.new_group([i for i in range(ngpus)])
         | 
| 155 | 
            +
                    dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                if not args.multiprocessing_distributed or gpu == 0:
         | 
| 158 | 
            +
                    eval_measures_cpu = eval_measures.cpu()
         | 
| 159 | 
            +
                    cnt = eval_measures_cpu[9].item()
         | 
| 160 | 
            +
                    eval_measures_cpu /= cnt
         | 
| 161 | 
            +
                    print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
         | 
| 162 | 
            +
                    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
         | 
| 163 | 
            +
                                                                                                 'sq_rel', 'log_rms', 'd1', 'd2',
         | 
| 164 | 
            +
                                                                                                 'd3'))
         | 
| 165 | 
            +
                    for i in range(8):
         | 
| 166 | 
            +
                        print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
         | 
| 167 | 
            +
                    print('{:7.4f}'.format(eval_measures_cpu[8]))
         | 
| 168 | 
            +
                    return eval_measures_cpu
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                return None
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
            def main_worker(gpu, ngpus_per_node, args):
         | 
| 174 | 
            +
                args.gpu = gpu
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                if args.gpu is not None:
         | 
| 177 | 
            +
                    print("== Use GPU: {} for training".format(args.gpu))
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                if args.distributed:
         | 
| 180 | 
            +
                    if args.dist_url == "env://" and args.rank == -1:
         | 
| 181 | 
            +
                        args.rank = int(os.environ["RANK"])
         | 
| 182 | 
            +
                    if args.multiprocessing_distributed:
         | 
| 183 | 
            +
                        args.rank = args.rank * ngpus_per_node + gpu
         | 
| 184 | 
            +
                    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                # model
         | 
| 187 | 
            +
                model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
         | 
| 188 | 
            +
                model.train()
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                num_params = sum([np.prod(p.size()) for p in model.parameters()])
         | 
| 191 | 
            +
                print("== Total number of parameters: {}".format(num_params))
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
         | 
| 194 | 
            +
                print("== Total number of learning parameters: {}".format(num_params_update))
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                if args.distributed:
         | 
| 197 | 
            +
                    if args.gpu is not None:
         | 
| 198 | 
            +
                        torch.cuda.set_device(args.gpu)
         | 
| 199 | 
            +
                        model.cuda(args.gpu)
         | 
| 200 | 
            +
                        args.batch_size = int(args.batch_size / ngpus_per_node)
         | 
| 201 | 
            +
                        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
         | 
| 202 | 
            +
                    else:
         | 
| 203 | 
            +
                        model.cuda()
         | 
| 204 | 
            +
                        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
         | 
| 205 | 
            +
                else:
         | 
| 206 | 
            +
                    model = torch.nn.DataParallel(model)
         | 
| 207 | 
            +
                    model.cuda()
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                if args.distributed:
         | 
| 210 | 
            +
                    print("== Model Initialized on GPU: {}".format(args.gpu))
         | 
| 211 | 
            +
                else:
         | 
| 212 | 
            +
                    print("== Model Initialized")
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                global_step = 0
         | 
| 215 | 
            +
                best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
         | 
| 216 | 
            +
                best_eval_measures_higher_better = torch.zeros(3).cpu()
         | 
| 217 | 
            +
                best_eval_steps = np.zeros(9, dtype=np.int32)
         | 
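    # Bookkeeping note: the nine tracked metrics follow eval_metrics in iebins/utils.py,
    # i.e. ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'];
    # the first six are lower-is-better (hence the 1e3 sentinel above), the last three
    # (the delta-threshold accuracies) are higher-is-better.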
| 218 | 
            +
             | 
| 219 | 
            +
                # Training parameters
         | 
| 220 | 
            +
                optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
         | 
| 221 | 
            +
                                            lr=args.learning_rate)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                model_just_loaded = False
         | 
| 224 | 
            +
                if args.checkpoint_path != '':
         | 
| 225 | 
            +
                    if os.path.isfile(args.checkpoint_path):
         | 
| 226 | 
            +
                        print("== Loading checkpoint '{}'".format(args.checkpoint_path))
         | 
| 227 | 
            +
                        if args.gpu is None:
         | 
| 228 | 
            +
                            checkpoint = torch.load(args.checkpoint_path)
         | 
| 229 | 
            +
                        else:
         | 
| 230 | 
            +
                            loc = 'cuda:{}'.format(args.gpu)
         | 
| 231 | 
            +
                            checkpoint = torch.load(args.checkpoint_path, map_location=loc)
         | 
| 232 | 
            +
                        model.load_state_dict(checkpoint['model'])
         | 
| 233 | 
            +
                        optimizer.load_state_dict(checkpoint['optimizer'])
         | 
| 234 | 
            +
                        if not args.retrain:
         | 
| 235 | 
            +
                            try:
         | 
| 236 | 
            +
                                global_step = checkpoint['global_step']
         | 
| 237 | 
            +
                                best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
         | 
| 238 | 
            +
                                best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
         | 
| 239 | 
            +
                                best_eval_steps = checkpoint['best_eval_steps']
         | 
| 240 | 
            +
                            except KeyError:
         | 
| 241 | 
            +
                                print("Could not load values for online evaluation")
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                        print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step']))
         | 
| 244 | 
            +
                    else:
         | 
| 245 | 
            +
                        print("== No checkpoint found at '{}'".format(args.checkpoint_path))
         | 
| 246 | 
            +
                    model_just_loaded = True
         | 
| 247 | 
            +
                    del checkpoint
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                cudnn.benchmark = True
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                dataloader = NewDataLoader(args, 'train')
         | 
| 252 | 
            +
                dataloader_eval = NewDataLoader(args, 'online_eval')
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                # ===== Evaluation before training ======
         | 
| 255 | 
            +
                # model.eval()
         | 
| 256 | 
            +
                # with torch.no_grad():
         | 
| 257 | 
            +
                #     eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                # Logging
         | 
| 260 | 
            +
                if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
         | 
| 261 | 
            +
                    writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
         | 
| 262 | 
            +
                    if args.do_online_eval:
         | 
| 263 | 
            +
                        if args.eval_summary_directory != '':
         | 
| 264 | 
            +
                            eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
         | 
| 265 | 
            +
                        else:
         | 
| 266 | 
            +
                            eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
         | 
| 267 | 
            +
                        eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                silog_criterion = silog_loss(variance_focus=args.variance_focus)
         | 
| 270 | 
            +
                sum_localdepth = Sum_depth().cuda(args.gpu)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                start_time = time.time()
         | 
| 273 | 
            +
                duration = 0
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                num_log_images = args.batch_size
         | 
| 276 | 
            +
                end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
         | 
| 279 | 
            +
                var_cnt = len(var_sum)
         | 
| 280 | 
            +
                var_sum = np.sum(var_sum)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                steps_per_epoch = len(dataloader.data)
         | 
| 285 | 
            +
                num_total_steps = args.num_epochs * steps_per_epoch
         | 
| 286 | 
            +
                epoch = global_step // steps_per_epoch
         | 
| 287 | 
            +
                
         | 
| 288 | 
            +
    group = dist.new_group([i for i in range(ngpus_per_node)]) if args.distributed else None  # dist is only initialized in distributed mode
         | 
| 289 | 
            +
                while epoch < args.num_epochs:
         | 
| 290 | 
            +
                    if args.distributed:
         | 
| 291 | 
            +
                        dataloader.train_sampler.set_epoch(epoch)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    for step, sample_batched in enumerate(dataloader.data):
         | 
| 294 | 
            +
                        optimizer.zero_grad()
         | 
| 295 | 
            +
                        before_op_time = time.time()
         | 
| 296 | 
            +
                        si_loss = 0
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                        image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
         | 
| 299 | 
            +
                        depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                        pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = model(image, epoch, step)
         | 
| 302 | 
            +
                        
         | 
| 303 | 
            +
                        if args.dataset == 'nyu':
         | 
| 304 | 
            +
                            mask = depth_gt > 0.1
         | 
| 305 | 
            +
                        else:
         | 
| 306 | 
            +
                            mask = depth_gt > 1.0
         | 
| 307 | 
            +
                        
         | 
| 308 | 
            +
                        max_tree_depth = len(pred_depths_r_list)         
         | 
| 309 | 
            +
                        for curr_tree_depth in range(max_tree_depth):
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                            si_loss += silog_criterion.forward(pred_depths_r_list[curr_tree_depth], depth_gt, mask.to(torch.bool))
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                        loss = si_loss
         | 
| 314 | 
            +
                            
         | 
| 315 | 
            +
                        loss.backward()
         | 
| 316 | 
            +
                        for param_group in optimizer.param_groups:
         | 
| 317 | 
            +
                            current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
         | 
| 318 | 
            +
                            param_group['lr'] = current_lr
         | 
| 319 | 
            +
             | 
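                        # Learning-rate schedule: polynomial ("poly") decay
                        #   lr(t) = (lr0 - lr_end) * (1 - t / T) ** 0.9 + lr_end,
                        # with lr_end defaulting to 0.1 * lr0 (see end_learning_rate above).
                        # Example with lr0 = 1e-4, lr_end = 1e-5: at t = T/2,
                        # (1 - 0.5) ** 0.9 ≈ 0.536, so lr ≈ 9e-5 * 0.536 + 1e-5 ≈ 5.8e-5.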
| 320 | 
            +
                        optimizer.step()
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
         | 
| 323 | 
            +
                            print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
         | 
| 324 | 
            +
                            # if np.isnan(loss.cpu().item()):
         | 
| 325 | 
            +
                            #     print('NaN in loss occurred. Aborting training.')
         | 
| 326 | 
            +
                            #     return -1
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                        duration += time.time() - before_op_time
         | 
| 329 | 
            +
                        if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
         | 
| 330 | 
            +
                            var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
         | 
| 331 | 
            +
                            var_cnt = len(var_sum)
         | 
| 332 | 
            +
                            var_sum = np.sum(var_sum)
         | 
| 333 | 
            +
                            examples_per_sec = args.batch_size / duration * args.log_freq
         | 
| 334 | 
            +
                            duration = 0
         | 
| 335 | 
            +
                            time_sofar = (time.time() - start_time) / 3600
         | 
| 336 | 
            +
                            training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
         | 
| 337 | 
            +
                            if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
         | 
| 338 | 
            +
                                print("{}".format(args.model_name))
         | 
| 339 | 
            +
                            print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
         | 
| 340 | 
            +
                            print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
         | 
| 343 | 
            +
                                                                        and args.rank % ngpus_per_node == 0):
         | 
| 344 | 
            +
                                writer.add_scalar('silog_loss', si_loss, global_step)
         | 
| 345 | 
            +
                                # writer.add_scalar('var_loss', var_loss, global_step)
         | 
| 346 | 
            +
                                writer.add_scalar('learning_rate', current_lr, global_step)
         | 
| 347 | 
            +
                                writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
         | 
| 348 | 
            +
                                depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e-3, depth_gt)
         | 
| 349 | 
            +
                                for i in range(num_log_images):
         | 
| 350 | 
            +
                                    if args.dataset == 'nyu':
         | 
| 351 | 
            +
                                        writer.add_image('depth_gt/image/{}'.format(i), colormap(depth_gt[i, :, :, :].data), global_step)
         | 
| 352 | 
            +
                                        writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)                            
         | 
| 353 | 
            +
                                        writer.add_image('depth_r_est0/image/{}'.format(i), colormap(pred_depths_r_list[0][i, :, :, :].data), global_step)
         | 
| 354 | 
            +
                                        writer.add_image('depth_r_est1/image/{}'.format(i), colormap(pred_depths_r_list[1][i, :, :, :].data), global_step)
         | 
| 355 | 
            +
                                        writer.add_image('depth_r_est2/image/{}'.format(i), colormap(pred_depths_r_list[2][i, :, :, :].data), global_step)
         | 
| 356 | 
            +
                                        writer.add_image('depth_r_est3/image/{}'.format(i), colormap(pred_depths_r_list[3][i, :, :, :].data), global_step)
         | 
| 357 | 
            +
                                        writer.add_image('depth_r_est4/image/{}'.format(i), colormap(pred_depths_r_list[4][i, :, :, :].data), global_step)
         | 
| 358 | 
            +
                                        writer.add_image('depth_r_est5/image/{}'.format(i), colormap(pred_depths_r_list[5][i, :, :, :].data), global_step)
         | 
| 359 | 
            +
                                        writer.add_image('depth_c_est0/image/{}'.format(i), colormap(pred_depths_c_list[0][i, :, :, :].data), global_step)
         | 
| 360 | 
            +
                                        writer.add_image('depth_c_est1/image/{}'.format(i), colormap(pred_depths_c_list[1][i, :, :, :].data), global_step)
         | 
| 361 | 
            +
                                        writer.add_image('depth_c_est2/image/{}'.format(i), colormap(pred_depths_c_list[2][i, :, :, :].data), global_step)
         | 
| 362 | 
            +
                                        writer.add_image('depth_c_est3/image/{}'.format(i), colormap(pred_depths_c_list[3][i, :, :, :].data), global_step)
         | 
| 363 | 
            +
                                        writer.add_image('depth_c_est4/image/{}'.format(i), colormap(pred_depths_c_list[4][i, :, :, :].data), global_step)
         | 
| 364 | 
            +
                                        writer.add_image('depth_c_est5/image/{}'.format(i), colormap(pred_depths_c_list[5][i, :, :, :].data), global_step)
         | 
| 365 | 
            +
                                    else:
         | 
| 366 | 
            +
                                        writer.add_image('depth_gt/image/{}'.format(i), colormap_magma(torch.log10(depth_gt[i, :, :, :].data)), global_step)
         | 
| 367 | 
            +
                                        writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)                            
         | 
| 368 | 
            +
                                        writer.add_image('depth_r_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[0][i, :, :, :].data)), global_step)
         | 
| 369 | 
            +
                                        writer.add_image('depth_r_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[1][i, :, :, :].data)), global_step)
         | 
| 370 | 
            +
                                        writer.add_image('depth_r_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[2][i, :, :, :].data)), global_step)
         | 
| 371 | 
            +
                                        writer.add_image('depth_r_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[3][i, :, :, :].data)), global_step)
         | 
| 372 | 
            +
                                        writer.add_image('depth_r_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[4][i, :, :, :].data)), global_step)
         | 
| 373 | 
            +
                                        writer.add_image('depth_r_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[5][i, :, :, :].data)), global_step)
         | 
| 374 | 
            +
                                        writer.add_image('depth_c_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[0][i, :, :, :].data)), global_step)
         | 
| 375 | 
            +
                                        writer.add_image('depth_c_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[1][i, :, :, :].data)), global_step)
         | 
| 376 | 
            +
                                        writer.add_image('depth_c_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[2][i, :, :, :].data)), global_step)
         | 
| 377 | 
            +
                                        writer.add_image('depth_c_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[3][i, :, :, :].data)), global_step)
         | 
| 378 | 
            +
                                        writer.add_image('depth_c_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[4][i, :, :, :].data)), global_step)
         | 
| 379 | 
            +
                                        writer.add_image('depth_c_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[5][i, :, :, :].data)), global_step)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                                    writer.add_image('uncer_est0/image/{}'.format(i), colormap(uncertainty_maps_list[0][i, :, :, :].data), global_step)
         | 
| 382 | 
            +
                                    writer.add_image('uncer_est1/image/{}'.format(i), colormap(uncertainty_maps_list[1][i, :, :, :].data), global_step)
         | 
| 383 | 
            +
                                    writer.add_image('uncer_est2/image/{}'.format(i), colormap(uncertainty_maps_list[2][i, :, :, :].data), global_step)
         | 
| 384 | 
            +
                                    writer.add_image('uncer_est3/image/{}'.format(i), colormap(uncertainty_maps_list[3][i, :, :, :].data), global_step)
         | 
| 385 | 
            +
                                    writer.add_image('uncer_est4/image/{}'.format(i), colormap(uncertainty_maps_list[4][i, :, :, :].data), global_step)
         | 
| 386 | 
            +
                                    writer.add_image('uncer_est5/image/{}'.format(i), colormap(uncertainty_maps_list[5][i, :, :, :].data), global_step)
         | 
| 387 | 
            +
                                           
         | 
| 388 | 
            +
                        if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
         | 
| 389 | 
            +
                            time.sleep(0.1)
         | 
| 390 | 
            +
                            model.eval()
         | 
| 391 | 
            +
                            with torch.no_grad():
         | 
| 392 | 
            +
                                eval_measures = online_eval(model, dataloader_eval, gpu, epoch, ngpus_per_node, group, post_process=True)
         | 
| 393 | 
            +
                            if eval_measures is not None:
         | 
| 394 | 
            +
                                exp_name = '%s'%(datetime.now().strftime('%m%d'))
         | 
| 395 | 
            +
                                log_txt = os.path.join(args.log_directory + '/' + args.model_name, exp_name+'_logs.txt')
         | 
| 396 | 
            +
                                with open(log_txt, 'a') as txtfile:
         | 
| 397 | 
            +
                                    txtfile.write(">>>>>>>>>>>>>>>>>>>>>>>>>Step:%d>>>>>>>>>>>>>>>>>>>>>>>>>\n"%(int(global_step)))
         | 
| 398 | 
            +
                                    txtfile.write("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}\n".format('silog', 
         | 
| 399 | 
            +
                                                    'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2','d3'))
         | 
| 400 | 
            +
                                    txtfile.write("depth estimation\n")
         | 
| 401 | 
            +
                                    line = ''
         | 
| 402 | 
            +
                                    for i in range(9):
         | 
| 403 | 
            +
                                        line +='{:7.4f}, '.format(eval_measures[i])
         | 
| 404 | 
            +
                                    txtfile.write(line+'\n')
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                                for i in range(9):
         | 
| 407 | 
            +
                                    eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
         | 
| 408 | 
            +
                                    measure = eval_measures[i]
         | 
| 409 | 
            +
                                    is_best = False
         | 
| 410 | 
            +
                                    if i < 6 and measure < best_eval_measures_lower_better[i]:
         | 
| 411 | 
            +
                                        old_best = best_eval_measures_lower_better[i].item()
         | 
| 412 | 
            +
                                        best_eval_measures_lower_better[i] = measure.item()
         | 
| 413 | 
            +
                                        is_best = True
         | 
| 414 | 
            +
                                    elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
         | 
| 415 | 
            +
                                        old_best = best_eval_measures_higher_better[i-6].item()
         | 
| 416 | 
            +
                                        best_eval_measures_higher_better[i-6] = measure.item()
         | 
| 417 | 
            +
                                        is_best = True
         | 
| 418 | 
            +
                                    if is_best:
         | 
| 419 | 
            +
                                        old_best_step = best_eval_steps[i]
         | 
| 420 | 
            +
                                        old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
         | 
| 421 | 
            +
                                        model_path = args.log_directory + '/' + args.model_name + old_best_name
         | 
| 422 | 
            +
                                        if os.path.exists(model_path):
         | 
| 423 | 
            +
                                            command = 'rm {}'.format(model_path)
         | 
| 424 | 
            +
                                            os.system(command)
         | 
| 425 | 
            +
                                        best_eval_steps[i] = global_step
         | 
| 426 | 
            +
                                        model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
         | 
| 427 | 
            +
                                        print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
         | 
| 428 | 
            +
                                        checkpoint = {'global_step': global_step,
         | 
| 429 | 
            +
                                                      'model': model.state_dict(),
         | 
| 430 | 
            +
                                                      'optimizer': optimizer.state_dict(),
         | 
| 431 | 
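online_eval above relies on flip_lr and post_process_depth, which live in iebins/utils.py beyond the excerpt shown here. A minimal sketch of the flip-augmentation idea they implement, under the assumption that the prediction and the re-flipped prediction of the mirrored input are simply averaged (the repository's own helper may additionally blend image borders):

    import torch

    def flip_lr_sketch(image):
        # image: (B, C, H, W); mirror along the width axis
        return torch.flip(image, dims=[3])

    def post_process_depth_sketch(depth, depth_from_flipped):
        # undo the mirroring on the second estimate, then average the two
        return 0.5 * (depth + torch.flip(depth_from_flipped, dims=[3]))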
            +
                                                      'best_eval_measures_higher_better': best_eval_measures_higher_better,
         | 
| 432 | 
            +
                                                      'best_eval_measures_lower_better': best_eval_measures_lower_better,
         | 
| 433 | 
            +
                                                      'best_eval_steps': best_eval_steps
         | 
| 434 | 
            +
                                                      }
         | 
| 435 | 
            +
                                        torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
         | 
| 436 | 
            +
                                eval_summary_writer.flush()
         | 
| 437 | 
            +
                            model.train()
         | 
| 438 | 
            +
                            block_print()
         | 
| 439 | 
            +
                            enable_print()
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                        model_just_loaded = False
         | 
| 442 | 
            +
                        global_step += 1
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    epoch += 1
         | 
| 445 | 
            +
                   
         | 
| 446 | 
            +
                if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
         | 
| 447 | 
            +
                    writer.close()
         | 
| 448 | 
            +
                    if args.do_online_eval:
         | 
| 449 | 
            +
                        eval_summary_writer.close()
         | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
            def main():
         | 
| 453 | 
            +
                if args.mode != 'train':
         | 
| 454 | 
            +
                    print('train.py is only for training.')
         | 
| 455 | 
            +
                    return -1
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                exp_name = '%s'%(datetime.now().strftime('%m%d'))  
         | 
| 458 | 
            +
                args.log_directory = os.path.join(args.log_directory,exp_name)  
         | 
| 459 | 
            +
    command = 'mkdir -p ' + os.path.join(args.log_directory, args.model_name)  # -p: the date-stamped log_directory may not exist yet
         | 
| 460 | 
            +
                os.system(command)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                args_out_path = os.path.join(args.log_directory, args.model_name)
         | 
| 463 | 
            +
                command = 'cp ' + sys.argv[1] + ' ' + args_out_path
         | 
| 464 | 
            +
                os.system(command)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                save_files = True
         | 
| 467 | 
            +
                if save_files:
         | 
| 468 | 
            +
                    aux_out_path = os.path.join(args.log_directory, args.model_name)
         | 
| 469 | 
            +
                    networks_savepath = os.path.join(aux_out_path, 'networks')
         | 
| 470 | 
            +
                    dataloaders_savepath = os.path.join(aux_out_path, 'dataloaders')
         | 
| 471 | 
            +
                    command = 'cp iebins/train.py ' + aux_out_path
         | 
| 472 | 
            +
                    os.system(command)
         | 
| 473 | 
            +
                    command = 'mkdir -p ' + networks_savepath + ' && cp iebins/networks/*.py ' + networks_savepath
         | 
| 474 | 
            +
                    os.system(command)
         | 
| 475 | 
            +
                    command = 'mkdir -p ' + dataloaders_savepath + ' && cp iebins/dataloaders/*.py ' + dataloaders_savepath
         | 
| 476 | 
            +
                    os.system(command)
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                torch.cuda.empty_cache()
         | 
| 479 | 
            +
                args.distributed = args.world_size > 1 or args.multiprocessing_distributed
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                ngpus_per_node = torch.cuda.device_count()
         | 
| 482 | 
            +
                if ngpus_per_node > 1 and not args.multiprocessing_distributed:
         | 
| 483 | 
            +
                    print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
         | 
| 484 | 
            +
                    return -1
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                if args.do_online_eval:
         | 
| 487 | 
            +
                    print("You have specified --do_online_eval.")
         | 
| 488 | 
            +
                    print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
         | 
| 489 | 
            +
                          .format(args.eval_freq))
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                if args.multiprocessing_distributed:
         | 
| 492 | 
            +
                    args.world_size = ngpus_per_node * args.world_size
         | 
| 493 | 
            +
                    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
         | 
| 494 | 
            +
                else:
         | 
| 495 | 
            +
                    main_worker(args.gpu, ngpus_per_node, args)
         | 
| 496 | 
            +
             | 
| 497 | 
            +
             | 
| 498 | 
            +
            if __name__ == '__main__':
         | 
| 499 | 
            +
                main()
         | 
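For a single node with, say, four visible GPUs and the defaults above (world_size=1, rank=0), the multiprocessing path sets world_size = 4 * 1 = 4, and each spawned main_worker computes its rank as 0 * 4 + gpu, i.e. ranks 0-3, which is what dist.init_process_group over the NCCL backend expects. Without --multiprocessing_distributed the script refuses to run on a multi-GPU machine unless CUDA_VISIBLE_DEVICES restricts it to one device.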
    	
        iebins/utils.py
    ADDED
    
    | @@ -0,0 +1,356 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            import torch.distributed as dist
         | 
| 5 | 
            +
            from torch.utils.data import Sampler
         | 
| 6 | 
            +
            from torchvision import transforms
         | 
| 7 | 
            +
            import matplotlib.pyplot as plt
         | 
| 8 | 
            +
            import os, sys
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import math
         | 
| 11 | 
            +
            import matplotlib.cm  # needed by colorize() below; 'import matplotlib.pyplot as plt' alone does not bind the name 'matplotlib'
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def convert_arg_line_to_args(arg_line):
         | 
| 15 | 
            +
                for arg in arg_line.split():
         | 
| 16 | 
            +
                    if not arg.strip():
         | 
| 17 | 
            +
                        continue
         | 
| 18 | 
            +
                    yield arg
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def block_print():
         | 
| 22 | 
            +
                sys.stdout = open(os.devnull, 'w')
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def enable_print():
         | 
| 26 | 
            +
                sys.stdout = sys.__stdout__
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def get_num_lines(file_path):
         | 
| 30 | 
            +
                f = open(file_path, 'r')
         | 
| 31 | 
            +
                lines = f.readlines()
         | 
| 32 | 
            +
                f.close()
         | 
| 33 | 
            +
                return len(lines)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def colorize(value, vmin=None, vmax=None, cmap='Greys'):
         | 
| 37 | 
            +
                value = value.cpu().numpy()[:, :, :]
         | 
| 38 | 
            +
                value = np.log10(value)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                vmin = value.min() if vmin is None else vmin
         | 
| 41 | 
            +
                vmax = value.max() if vmax is None else vmax
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                if vmin != vmax:
         | 
| 44 | 
            +
                    value = (value - vmin) / (vmax - vmin)
         | 
| 45 | 
            +
                else:
         | 
| 46 | 
            +
                    value = value*0.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                cmapper = matplotlib.cm.get_cmap(cmap)
         | 
| 49 | 
            +
                value = cmapper(value, bytes=True)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                img = value[:, :, :3]
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                return img.transpose((2, 0, 1))
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def normalize_result(value, vmin=None, vmax=None):
         | 
| 57 | 
            +
                value = value.cpu().numpy()[0, :, :]
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                vmin = value.min() if vmin is None else vmin
         | 
| 60 | 
            +
                vmax = value.max() if vmax is None else vmax
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                if vmin != vmax:
         | 
| 63 | 
            +
                    value = (value - vmin) / (vmax - vmin)
         | 
| 64 | 
            +
                else:
         | 
| 65 | 
            +
                    value = value * 0.
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                return np.expand_dims(value, 0)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            inv_normalize = transforms.Normalize(
         | 
| 71 | 
            +
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
         | 
| 72 | 
            +
                std=[1/0.229, 1/0.224, 1/0.225]
         | 
| 73 | 
            +
            )
         | 
| 74 | 
            +
             | 
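inv_normalize undoes the standard ImageNet normalization Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): since Normalize maps x to (x - m) / s, composing it with Normalize(-m/s, 1/s) returns the original pixel values, which is what train.py uses when logging input images to TensorBoard.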
| 75 | 
            +
             | 
| 76 | 
            +
            eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def compute_errors(gt, pred):
         | 
| 80 | 
            +
                thresh = np.maximum((gt / pred), (pred / gt))
         | 
| 81 | 
            +
                d1 = (thresh < 1.25).mean()
         | 
| 82 | 
            +
                d2 = (thresh < 1.25 ** 2).mean()
         | 
| 83 | 
            +
                d3 = (thresh < 1.25 ** 3).mean()
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                rms = (gt - pred) ** 2
         | 
| 86 | 
            +
                rms = np.sqrt(rms.mean())
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                log_rms = (np.log(gt) - np.log(pred)) ** 2
         | 
| 89 | 
            +
                log_rms = np.sqrt(log_rms.mean())
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                abs_rel = np.mean(np.abs(gt - pred) / gt)
         | 
| 92 | 
            +
                sq_rel = np.mean(((gt - pred) ** 2) / gt)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                err = np.log(pred) - np.log(gt)
         | 
| 95 | 
            +
                silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                err = np.abs(np.log10(pred) - np.log10(gt))
         | 
| 98 | 
            +
                log10 = np.mean(err)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
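A quick, purely illustrative sanity check of compute_errors on synthetic values:

    import numpy as np

    gt = np.array([1.0, 2.0, 4.0])
    pred = np.array([1.1, 1.8, 4.4])
    silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3 = compute_errors(gt, pred)
    # every ratio max(gt/pred, pred/gt) is below 1.25, so d1 == d2 == d3 == 1.0,
    # and abs_rel = mean(|gt - pred| / gt) = mean([0.1, 0.1, 0.1]) = 0.1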
| 103 | 
            +
            class silog_loss(nn.Module):
         | 
| 104 | 
            +
                def __init__(self, variance_focus):
         | 
| 105 | 
            +
                    super(silog_loss, self).__init__()
         | 
| 106 | 
            +
                    self.variance_focus = variance_focus
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def forward(self, depth_est, depth_gt, mask):
         | 
| 109 | 
            +
                    d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
         | 
| 110 | 
            +
                    return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
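The class above is the scale-invariant log (SILog) loss: with d_i = log(pred_i) - log(gt_i) over the masked pixels, it returns 10 * sqrt(mean(d^2) - variance_focus * mean(d)^2). With variance_focus = 1 the loss depends only on the variance of d and is therefore invariant to a global scaling of the prediction; the default of 0.85 in train.py keeps some pressure on the absolute scale, as the argument's help text notes.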
| 113 | 
            +
            def entropy_loss(preds, gt_label, mask):
         | 
| 114 | 
            +
                # preds: B, C, H, W
         | 
| 115 | 
            +
                # gt_label: B, H, W
         | 
| 116 | 
            +
                # mask: B, H, W
         | 
| 117 | 
            +
                mask = mask > 0.0 # B, H, W
         | 
| 118 | 
            +
                preds = preds.permute(0, 2, 3, 1) # B, H, W, C
         | 
| 119 | 
            +
                preds_mask = preds[mask] # N, C
         | 
| 120 | 
            +
                gt_label_mask = gt_label[mask] # N
         | 
| 121 | 
            +
                loss = F.cross_entropy(preds_mask, gt_label_mask, reduction='mean')
         | 
| 122 | 
            +
                return loss
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
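`silog_loss` is the scale-invariant log loss: with `variance_focus` at 1.0 it reduces to the pure SILog term, and at 0.0 to RMSE in log space. A hedged training-step sketch (the 0.85 value, tensor shapes and validity threshold are illustrative assumptions):

```python
import torch

# Hypothetical predicted / ground-truth depth, shape (B, 1, H, W), strictly positive.
depth_est = torch.rand(2, 1, 120, 160, requires_grad=True) * 10 + 0.1
depth_gt = torch.rand(2, 1, 120, 160) * 10 + 0.1

criterion = silog_loss(variance_focus=0.85)  # illustrative value
mask = depth_gt > 0.1                        # supervise only valid pixels
loss = criterion(depth_est, depth_gt, mask)
loss.backward()
```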
def colormap(inputs, normalize=True, torch_transpose=True):
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('jet', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d

    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)

    return vis[0, :, :, :]


def colormap_magma(inputs, normalize=True, torch_transpose=True):
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('magma', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d

    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)

    return vis[0, :, :, :]

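`colormap` and `colormap_magma` turn a depth tensor into a CHW RGB visualization in [0, 1]; they rely on `plt` (matplotlib.pyplot) being available in this module. An illustrative conversion to an HWC uint8 image for display:

```python
import numpy as np
import torch

# Hypothetical network output: a single depth map of shape (1, 1, H, W).
depth = torch.rand(1, 1, 240, 320)

vis = colormap_magma(depth)                                   # (3, H, W) float RGB in [0, 1]
vis_uint8 = (vis.transpose(1, 2, 0) * 255).astype(np.uint8)   # (H, W, 3) for saving/display
```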
def flip_lr(image):
    """
    Flip image horizontally

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Image to be flipped

    Returns
    -------
    image_flipped : torch.Tensor [B,3,H,W]
        Flipped image
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return torch.flip(image, [3])


def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
    """
    Fuse inverse depth and flipped inverse depth maps

    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map
    inv_depth_hat : torch.Tensor [B,1,H,W]
        Flipped inverse depth map produced from a flipped image
    method : str
        Method that will be used to fuse the inverse depth maps

    Returns
    -------
    fused_inv_depth : torch.Tensor [B,1,H,W]
        Fused inverse depth map
    """
    if method == 'mean':
        return 0.5 * (inv_depth + inv_depth_hat)
    elif method == 'max':
        return torch.max(inv_depth, inv_depth_hat)
    elif method == 'min':
        return torch.min(inv_depth, inv_depth_hat)
    else:
        raise ValueError('Unknown post-process method {}'.format(method))


def post_process_depth(depth, depth_flipped, method='mean'):
    """
    Post-process a depth map together with the prediction from a horizontally flipped input

    Parameters
    ----------
    depth : torch.Tensor [B,1,H,W]
        Depth map
    depth_flipped : torch.Tensor [B,1,H,W]
        Depth map produced from a flipped image
    method : str
        Method that will be used to fuse the depth maps

    Returns
    -------
    depth_pp : torch.Tensor [B,1,H,W]
        Post-processed depth map
    """
    B, C, H, W = depth.shape
    inv_depth_hat = flip_lr(depth_flipped)
    inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method)
    xs = torch.linspace(0., 1., W, device=depth.device,
                        dtype=depth.dtype).repeat(B, C, H, 1)
    mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
    mask_hat = flip_lr(mask)
    return mask_hat * depth + mask * inv_depth_hat + \
           (1.0 - mask - mask_hat) * inv_depth_fused

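`post_process_depth` implements flip-augmented inference: predict once on the image and once on its mirror, un-flip the second prediction, and blend the two with ramp masks near the left and right borders. A sketch of how it could be wired around a depth network (the `model` interface here is an assumption, not something defined in this file):

```python
import torch

def infer_with_flip(model, image):
    """image: (B, 3, H, W); model is assumed to return a (B, 1, H, W) depth map."""
    with torch.no_grad():
        depth = model(image)
        depth_flipped = model(flip_lr(image))
    # Fuse the two predictions; the ramp masks above down-weight border artefacts.
    return post_process_depth(depth, depth_flipped, method='mean')
```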
class DistributedSamplerNoEvenlyDivisible(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
        rest = len(self.dataset) - num_samples * self.num_replicas
        if self.rank < rest:
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = len(dataset)
        # self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        # indices += indices[:(self.total_size - len(indices))]
        # assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(indices)
        # assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

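Unlike the stock `DistributedSampler`, `DistributedSamplerNoEvenlyDivisible` never pads or duplicates indices to make the per-rank split evenly divisible, which is what you want for evaluation. A small sketch with an explicit rank and world size (in a real run these would come from `torch.distributed`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(101, 3, 32, 32))  # 101 samples: not divisible by 2
sampler = DistributedSamplerNoEvenlyDivisible(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=1, sampler=sampler)

# Rank 0 gets 51 samples, rank 1 gets 50; together they cover the dataset exactly once.
print(len(sampler), len(loader))
```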
class D_to_cloud(nn.Module):
    """Layer to transform depth into point cloud
    """
    def __init__(self, batch_size, height, width):
        super(D_to_cloud, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)  # 2, H, W
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), requires_grad=False)  # 2, H, W

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)  # B, 1, L (L = H * W)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)  # 1, 2, L
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)  # B, 2, L
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), requires_grad=False)  # B, 3, L

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points

        return cam_points.permute(0, 2, 1)
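`D_to_cloud` back-projects a depth map into a camera-space point cloud using the inverse intrinsics `inv_K`. A usage sketch with made-up pinhole intrinsics (the 4x4 `K` and its values are illustrative):

```python
import numpy as np
import torch

H, W = 480, 640
K = np.array([[525.0,   0.0, 320.0, 0.0],
              [  0.0, 525.0, 240.0, 0.0],
              [  0.0,   0.0,   1.0, 0.0],
              [  0.0,   0.0,   0.0, 1.0]], dtype=np.float32)
inv_K = torch.from_numpy(np.linalg.inv(K)).unsqueeze(0)   # (1, 4, 4)

backproject = D_to_cloud(batch_size=1, height=H, width=W)
depth = torch.rand(1, 1, H, W) * 10.0
points = backproject(depth, inv_K)                        # (1, H*W, 3) camera-space points
```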
    	
iebins/utils/transfrom.py
ADDED
import random
from PIL import Image, ImageOps, ImageFilter
import torch
from torchvision import transforms
import torch.nn.functional as F

import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.
    Args:
        sample (dict): sample
        size (tuple): image size
    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.
        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            if "semseg_mask" in sample:
                # sample["semseg_mask"] = cv2.resize(
                #     sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
                # )
                sample["semseg_mask"] = F.interpolate(
                    torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...],
                    (height, width), mode='nearest').numpy()[0, 0]

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                # sample["mask"] = sample["mask"].astype(bool)

        # print(sample['image'].shape, sample['depth'].shape)
        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        if "semseg_mask" in sample:
            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])

        return sample
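Taken together, `Resize`, `NormalizeImage` and `PrepareForNet` form a dict-based preprocessing pipeline that can be chained with `torchvision.transforms.Compose`. A hedged sketch of inference preprocessing (the target size, the multiple-of-32 constraint, the ImageNet statistics and the image path are illustrative choices, not read from this file):

```python
import cv2
import numpy as np
import torch
from torchvision.transforms import Compose

transform = Compose([
    Resize(width=640, height=480, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

# 'example.jpg' is a placeholder path; the image is converted to HWC float RGB in [0, 1].
raw = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({'image': raw})
image = torch.from_numpy(sample['image']).unsqueeze(0)  # (1, 3, H', W') network input
```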
    	
requirements.txt
ADDED
pytorch=1.10.0
torchvision
cudatoolkit=11.1
matplotlib
tqdm
tensorboardX
timm
mmcv
open3d
gradio_imageslider
torch
opencv-python
