Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						b40a1f8
	
1
								Parent(s):
							
							91e5f9b
								
added unpaired low-light dataset
Browse files- enhance_me/commons.py +15 -0
- enhance_me/zero_dce/zero_dce.py +18 -3
- test.py +4 -0
    	
        enhance_me/commons.py
    CHANGED
    
    | @@ -61,3 +61,18 @@ def download_lol_dataset(): | |
| 61 | 
             
                test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
         | 
| 62 | 
             
                assert len(test_low_images) == len(test_enhanced_images)
         | 
| 63 | 
             
                return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 61 | 
             
                test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
         | 
| 62 | 
             
                assert len(test_low_images) == len(test_enhanced_images)
         | 
| 63 | 
             
                return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def download_unpaired_low_light_dataset():
         | 
| 67 | 
            +
                utils.get_file(
         | 
| 68 | 
            +
                    "low_light_dataset.zip",
         | 
| 69 | 
            +
                    "https://github.com/soumik12345/enhance-me/releases/download/v0.3/low_light_dataset.zip",
         | 
| 70 | 
            +
                    cache_dir="./",
         | 
| 71 | 
            +
                    cache_subdir="./datasets",
         | 
| 72 | 
            +
                    extract=True,
         | 
| 73 | 
            +
                )
         | 
| 74 | 
            +
                low_images = glob("./datasets/low_light_dataset/*.png")
         | 
| 75 | 
            +
                test_low_images = sorted(glob("./datasets/low_light_dataset/eval15/low/*"))
         | 
| 76 | 
            +
                test_enhanced_images = sorted(glob("./datasets/low_light_dataset/eval15/high/*"))
         | 
| 77 | 
            +
                assert len(test_low_images) == len(test_enhanced_images)
         | 
| 78 | 
            +
                return low_images, (test_low_images, test_enhanced_images)
         | 
    	
        enhance_me/zero_dce/zero_dce.py
    CHANGED
    
    | @@ -16,15 +16,25 @@ from .losses import ( | |
| 16 | 
             
                illumination_smoothness_loss,
         | 
| 17 | 
             
                SpatialConsistencyLoss,
         | 
| 18 | 
             
            )
         | 
| 19 | 
            -
            from ..commons import  | 
|  | |
|  | |
|  | |
|  | |
| 20 |  | 
| 21 |  | 
| 22 | 
             
            class ZeroDCE(Model):
         | 
| 23 | 
            -
                def __init__( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 24 | 
             
                    super(ZeroDCE, self).__init__(**kwargs)
         | 
| 25 | 
             
                    self.experiment_name = experiment_name
         | 
| 26 | 
             
                    if use_mixed_precision:
         | 
| 27 | 
            -
                        policy = mixed_precision.Policy( | 
| 28 | 
             
                        mixed_precision.set_global_policy(policy)
         | 
| 29 | 
             
                    if wandb_api_key is not None:
         | 
| 30 | 
             
                        init_wandb("zero-dce", experiment_name, wandb_api_key)
         | 
| @@ -125,6 +135,11 @@ class ZeroDCE(Model): | |
| 125 | 
             
                ) -> None:
         | 
| 126 | 
             
                    if dataset_label == "lol":
         | 
| 127 | 
             
                        (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 128 | 
             
                    data_loader = UnpairedLowLightDataset(
         | 
| 129 | 
             
                        image_size,
         | 
| 130 | 
             
                        apply_resize,
         | 
|  | |
| 16 | 
             
                illumination_smoothness_loss,
         | 
| 17 | 
             
                SpatialConsistencyLoss,
         | 
| 18 | 
             
            )
         | 
| 19 | 
            +
            from ..commons import (
         | 
| 20 | 
            +
                download_lol_dataset,
         | 
| 21 | 
            +
                download_unpaired_low_light_dataset,
         | 
| 22 | 
            +
                init_wandb,
         | 
| 23 | 
            +
            )
         | 
| 24 |  | 
| 25 |  | 
| 26 | 
             
            class ZeroDCE(Model):
         | 
| 27 | 
            +
                def __init__(
         | 
| 28 | 
            +
                    self,
         | 
| 29 | 
            +
                    experiment_name=None,
         | 
| 30 | 
            +
                    wandb_api_key=None,
         | 
| 31 | 
            +
                    use_mixed_precision: bool = False,
         | 
| 32 | 
            +
                    **kwargs
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
             
                    super(ZeroDCE, self).__init__(**kwargs)
         | 
| 35 | 
             
                    self.experiment_name = experiment_name
         | 
| 36 | 
             
                    if use_mixed_precision:
         | 
| 37 | 
            +
                        policy = mixed_precision.Policy("mixed_float16")
         | 
| 38 | 
             
                        mixed_precision.set_global_policy(policy)
         | 
| 39 | 
             
                    if wandb_api_key is not None:
         | 
| 40 | 
             
                        init_wandb("zero-dce", experiment_name, wandb_api_key)
         | 
|  | |
| 135 | 
             
                ) -> None:
         | 
| 136 | 
             
                    if dataset_label == "lol":
         | 
| 137 | 
             
                        (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
         | 
| 138 | 
            +
                    elif dataset_label == "unpaired":
         | 
| 139 | 
            +
                        self.low_images, (
         | 
| 140 | 
            +
                            self.test_low_images,
         | 
| 141 | 
            +
                            _,
         | 
| 142 | 
            +
                        ) = download_unpaired_low_light_dataset()
         | 
| 143 | 
             
                    data_loader = UnpairedLowLightDataset(
         | 
| 144 | 
             
                        image_size,
         | 
| 145 | 
             
                        apply_resize,
         | 
    	
        test.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from enhance_me.commons import download_unpaired_low_light_dataset
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            download_unpaired_low_light_dataset()
         | 
