Spaces:
Runtime error
Runtime error
Commit
·
4cf5013
1
Parent(s):
d7c1491
added resize option
Browse files
enhance_me/zero_dce/dataloader.py
CHANGED
|
@@ -9,20 +9,31 @@ class UnpairedLowLightDataset:
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
image_size: int = 256,
|
|
|
|
| 12 |
apply_random_horizontal_flip: bool = True,
|
| 13 |
apply_random_vertical_flip: bool = True,
|
| 14 |
apply_random_rotation: bool = True,
|
| 15 |
) -> None:
|
| 16 |
self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
|
|
|
|
|
|
|
| 17 |
self.apply_random_horizontal_flip = apply_random_horizontal_flip
|
| 18 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
| 19 |
self.apply_random_rotation = apply_random_rotation
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
| 22 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
| 23 |
dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
|
| 24 |
-
dataset =
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
if is_train:
|
| 28 |
dataset = (
|
|
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
image_size: int = 256,
|
| 12 |
+
apply_resize: bool = False,
|
| 13 |
apply_random_horizontal_flip: bool = True,
|
| 14 |
apply_random_vertical_flip: bool = True,
|
| 15 |
apply_random_rotation: bool = True,
|
| 16 |
) -> None:
|
| 17 |
self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
|
| 18 |
+
self.image_size = image_size
|
| 19 |
+
self.apply_resize = apply_resize
|
| 20 |
self.apply_random_horizontal_flip = apply_random_horizontal_flip
|
| 21 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
| 22 |
self.apply_random_rotation = apply_random_rotation
|
| 23 |
|
| 24 |
+
def _resize(self, image):
|
| 25 |
+
return tf.image.resize(image, (self.image_size, self.image_size))
|
| 26 |
+
|
| 27 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
| 28 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
| 29 |
dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
|
| 30 |
+
dataset = (
|
| 31 |
+
dataset.map(
|
| 32 |
+
self.augmentation_factory.random_crop,
|
| 33 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 34 |
+
)
|
| 35 |
+
if not self.apply_resize
|
| 36 |
+
else dataset.map(self._resize, num_parallel_calls=tf.data.AUTOTUNE)
|
| 37 |
)
|
| 38 |
if is_train:
|
| 39 |
dataset = (
|
enhance_me/zero_dce/zero_dce.py
CHANGED
|
@@ -113,6 +113,7 @@ class ZeroDCE(Model):
|
|
| 113 |
self,
|
| 114 |
image_size: int = 256,
|
| 115 |
dataset_label: str = "lol",
|
|
|
|
| 116 |
apply_random_horizontal_flip: bool = True,
|
| 117 |
apply_random_vertical_flip: bool = True,
|
| 118 |
apply_random_rotation: bool = True,
|
|
@@ -123,6 +124,7 @@ class ZeroDCE(Model):
|
|
| 123 |
(self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
|
| 124 |
data_loader = UnpairedLowLightDataset(
|
| 125 |
image_size,
|
|
|
|
| 126 |
apply_random_horizontal_flip,
|
| 127 |
apply_random_vertical_flip,
|
| 128 |
apply_random_rotation,
|
|
@@ -130,7 +132,7 @@ class ZeroDCE(Model):
|
|
| 130 |
self.train_dataset, self.val_dataset = data_loader.get_datasets(
|
| 131 |
self.low_images, val_split, batch_size
|
| 132 |
)
|
| 133 |
-
|
| 134 |
def train(self, epochs: int):
|
| 135 |
log_dir = os.path.join(
|
| 136 |
self.experiment_name,
|
|
@@ -148,7 +150,7 @@ class ZeroDCE(Model):
|
|
| 148 |
callbacks=callbacks,
|
| 149 |
)
|
| 150 |
return history
|
| 151 |
-
|
| 152 |
def infer(self, original_image):
|
| 153 |
image = keras.preprocessing.image.img_to_array(original_image)
|
| 154 |
image = image.astype("float32") / 255.0
|
|
@@ -157,7 +159,7 @@ class ZeroDCE(Model):
|
|
| 157 |
output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
|
| 158 |
output_image = Image.fromarray(output_image.numpy())
|
| 159 |
return output_image
|
| 160 |
-
|
| 161 |
def infer_from_file(self, original_image_file: str):
|
| 162 |
original_image = Image.open(original_image_file)
|
| 163 |
return self.infer(original_image)
|
|
|
|
| 113 |
self,
|
| 114 |
image_size: int = 256,
|
| 115 |
dataset_label: str = "lol",
|
| 116 |
+
apply_resize: bool = False,
|
| 117 |
apply_random_horizontal_flip: bool = True,
|
| 118 |
apply_random_vertical_flip: bool = True,
|
| 119 |
apply_random_rotation: bool = True,
|
|
|
|
| 124 |
(self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
|
| 125 |
data_loader = UnpairedLowLightDataset(
|
| 126 |
image_size,
|
| 127 |
+
apply_resize,
|
| 128 |
apply_random_horizontal_flip,
|
| 129 |
apply_random_vertical_flip,
|
| 130 |
apply_random_rotation,
|
|
|
|
| 132 |
self.train_dataset, self.val_dataset = data_loader.get_datasets(
|
| 133 |
self.low_images, val_split, batch_size
|
| 134 |
)
|
| 135 |
+
|
| 136 |
def train(self, epochs: int):
|
| 137 |
log_dir = os.path.join(
|
| 138 |
self.experiment_name,
|
|
|
|
| 150 |
callbacks=callbacks,
|
| 151 |
)
|
| 152 |
return history
|
| 153 |
+
|
| 154 |
def infer(self, original_image):
|
| 155 |
image = keras.preprocessing.image.img_to_array(original_image)
|
| 156 |
image = image.astype("float32") / 255.0
|
|
|
|
| 159 |
output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
|
| 160 |
output_image = Image.fromarray(output_image.numpy())
|
| 161 |
return output_image
|
| 162 |
+
|
| 163 |
def infer_from_file(self, original_image_file: str):
|
| 164 |
original_image = Image.open(original_image_file)
|
| 165 |
return self.infer(original_image)
|
notebooks/enhance_me_train.ipynb
CHANGED
|
@@ -190,6 +190,7 @@
|
|
| 190 |
"experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
|
| 191 |
"image_size = 256 # @param {type:\"integer\"}\n",
|
| 192 |
"dataset_label = \"lol\" # @param [\"lol\"]\n",
|
|
|
|
| 193 |
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
| 194 |
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
| 195 |
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
|
@@ -223,6 +224,7 @@
|
|
| 223 |
"zero_dce.build_datasets(\n",
|
| 224 |
" image_size=image_size,\n",
|
| 225 |
" dataset_label=dataset_label,\n",
|
|
|
|
| 226 |
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
| 227 |
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
| 228 |
" apply_random_rotation=apply_random_rotation,\n",
|
|
|
|
| 190 |
"experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
|
| 191 |
"image_size = 256 # @param {type:\"integer\"}\n",
|
| 192 |
"dataset_label = \"lol\" # @param [\"lol\"]\n",
|
| 193 |
+
"apply_resize = False # @param {type:\"boolean\"}\n",
|
| 194 |
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
| 195 |
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
| 196 |
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
|
|
|
| 224 |
"zero_dce.build_datasets(\n",
|
| 225 |
" image_size=image_size,\n",
|
| 226 |
" dataset_label=dataset_label,\n",
|
| 227 |
+
" apply_resize=apply_resize,\n",
|
| 228 |
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
| 229 |
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
| 230 |
" apply_random_rotation=apply_random_rotation,\n",
|